RLHF项目2:营销话术助手¶
学习目标:
1.了解项目背景和项目架构
2.完成RLHF+Qwen的营销话术助手项目代码开发
3.通过UI界面进行数据采集,并完成强化学习模型的训练
4.1 项目背景¶
传统的营销话术系统往往采用固定的规则或模板,难以根据用户的实时反馈动态调整策略,导致用户体验不佳和转化率低下。 本项目旨在构建一个智能的营销话术助手,通过结合深度强化学习和大语言模型技术,实现能够根据对话上下文和用户反馈动态调整营销策略的智能对话系统。系统采用DQN(Deep Q-Network)强化学习算法作为决策核心,通过与用户的真实对话交互收集反馈数据,不断优化营销话术策略。
这种将强化学习应用于营销对话系统的解决方案,不仅提升了对话系统的智能化水平,也为企业提供了更加精准和个性化的营销策略制定工具,具有重要的实用价值和广阔的市场前景。
4.2 项目结构概览¶
Marketing_Script_Assistant/
├── environment/ # 强化学习环境
│ └── dialogue_env.py
├── agents/ # DQN Agent 实现
│ └── dqn_agent.py # dqn智能体
│ └── offline_dqn.py # 离线DQN模型
├── llm/ # 大模型接口
│ └── qwen_api.py # 阿里云 Qwen 接口
├── utils/ # 工具函数
│ ├── state_utils.py # 构建对话状态向量
│ ├── sentence_encoder.py # 对输入文本进行编码,返回句向量
│ └── data_logger.py # 记录训练数据
├── models/ # 存放训练好的模型
│ └── dqn_marketing_model.zip # 保存训练好的dqn模型
│ └── paraphrase-multilingual-MiniLM-L12-v2 # 预训练的Embedding模型
├── data/ # 数据文件
│ └── dialogue_logs.json
├── train.py # 模型训练脚本
└── ui/ # Streamlit 前端界面
└── app.py
4.3 项目开发¶
4.3.1 定义环境¶
- 文件位置:Marketing_Script_Assistant/environment/dialogue_env.py
环境的动作空间中,定义10种营销策略动作,并给出了参考示例。
状态是连续的,动作是离散的。
# -*- coding: utf-8 -*-
import gymnasium as gym
from gymnasium import spaces
import numpy as np
class MarketingDialogueEnv(gym.Env):
"""
营销对话环境类,继承自gym.Env
用于训练智能客服代理选择合适的营销策略进行回复
"""
def __init__(self, mode='offline'):
"""
初始化营销对话环境
:param mode: 运行模式,默认为'offline'离线模式
"""
super(MarketingDialogueEnv, self).__init__()
self.mode = mode # 默认离线模式,使用人工反馈收集的数据进行强化学习训练
# 定义10种营销策略动作,每个动作包含策略描述和示例
self.actions = {
0: ('本次回复使用:情感共鸣策略回答问题,目的是触发客户情感需求(家庭、信任、归属感)',
'例如:“这款安全座椅,像妈妈的手一样守护宝宝每一次出行。”'),
1: ('本次回复使用:价值重构策略回答问题,目的是将价格转化为长期投资,弱化成本感',
'例如:“这不仅是软件,是帮您每月节省7500元成本的工具。”'),
2: ('本次回复使用:社会认同策略回答问题,目的是利用从众心理和权威背书',
'例如:“您行业的XX公司都在用,上周刚续约。”'),
3: ('本次回复使用:故事叙述策略回答问题,目的是用真实案例引发共情,替代说教',
'例如:“一位妈妈用这款净水器后,孩子再没闹过肚子。”'),
4: ('本次回复使用:权威专业策略回答问题,目的是引用认证/数据提升可信度',
'例如:“产品通过欧盟安全认证,测试报告显示故障率低于0.1%。”'),
5: ('本次回复使用:紧迫感策略回答问题,目的是制造稀缺性促立即行动',
'例如:“本周签约可免费升级,仅剩3个名额。”'),
6: ('本次回复使用:差异化优势策略回答问题,目的是对比竞品突出独特价值',
'例如:“同类产品只保修1年,我们提供3年+24小时响应。”'),
7: ('本次回复使用:免费/试用策略回答问题,目的是降低决策门槛,转移风险',
'例如:“先试用7天,无效全额退款。”'),
8: ('本次回复使用:个性化定制策略回答问题,目的是针对需求提供专属方案',
'例如:“根据贵公司物流需求,我们设计了分仓备货方案。”'),
9: ('本次回复使用:异议转化策略回答问题,目的是将反对点转化为卖点',
'例如:“您说价格高,正因我们用了航天级材料,寿命延长5倍。”')
}
# action_space 表示智能体(Agent)在环境中可以采取的所有动作的集合。
# 本项目中定义了10个营销策略,每个策略对应一个动作。
self.action_space = spaces.Discrete(len(self.actions))
# observation_space 描述了 Agent 所能"看到"的环境状态(即输入给模型的状态)。
# 在本项目中,状态是通过对话历史中的所有文本(包括用户和机器人的发言)进行编码并聚合得到的。
# 词嵌入模型是paraphrase-multilingual-MiniLM-L12-v2,维度是384
# 因此,使用的是 Box 空间类型,表示连续值状态向量,最小值-np.inf,最大值为-np.inf,长度为384。
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(384,), dtype=np.float32) # 示例:简单状态表示
self.state = None
self._reset_state()
def _reset_state(self):
"""
重置环境状态为随机向量
"""
self.state = np.random.rand(384).astype(np.float32)
def reset(self, seed=None, options=None):
"""
重置环境到初始状态
:param seed: 随机种子
:param options: 其他选项参数
:return: 初始状态和额外信息字典
"""
super().reset(seed=seed)
return np.zeros(384, dtype=np.float32), {}
def step(self, action):
"""
执行一个动作并返回环境的新状态
:param action: 要执行的动作索引
:return: 新状态、奖励、是否结束、是否截断、额外信息
"""
# 使用已有数据做 offline RL,不会进入step函数
if self.mode == "offline":
# 本案例是离线训练,正常情况下不会调用step函数
# raise NotImplementedError("离线训练模式下不应调用 step()")
# 返回 dummy 数据,防止离线模式,意外调用step函数,提高健壮性
return np.zeros(384, dtype=np.float32), 0.0, True, False, {}
4.3.2 定义强化学习agent¶
- 文件位置:Marketing_Script_Assistant/agents/dqn_agent.py
"""
本项目是RLHF任务,通过人工反馈进行学习。为了更接近实际使用场景,选择off-policy来进行训练。
算法层面,选用DQN算法, 它是一种off-policy算法,使用 replay buffer 存储历史经验,
支持手动扩展 replay buffer,通过 replay_buffer.add()支持外部数据,
"""
from agents.offline_dqn import OfflineDQN
from stable_baselines3.common.buffers import ReplayBuffer
import os
import numpy as np
project_root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
MODEL_PATH = os.path.join(project_root, "models", "dqn_marketing_model")
class DQNAgent:
def __init__(self, env):
self.env = env
if os.path.exists(f"{MODEL_PATH}.zip"):
self.model = OfflineDQN.load(MODEL_PATH)
print("✅ 加载已有模型")
else:
self.model = OfflineDQN("MlpPolicy",
self.env,
verbose=1,
buffer_size=100000,
exploration_fraction=0.3, # 增加探索时间比例
exploration_initial_eps=1.0, # 初始探索率
exploration_final_eps=0.01, # 最终探索率(更小的值)
learning_starts=1000, # 增加开始学习前的步数
target_update_interval=500 # 更新目标网络的频率
)
self.model.exploration_rate = 1.0
print("使用新模型初始化,并强制开启 100% 随机探索")
print("🆕 使用新模型初始化")
def train(self, total_timesteps, dataset=None):
if dataset is not None:
# # 将 dataset 转换为 replay buffer 支持的格式
# observations = []
# actions = []
# rewards = []
# next_observations = []
# dones = []
buffer_size = len(dataset)
self.model.replay_buffer = ReplayBuffer(
buffer_size,
self.model.observation_space,
self.model.action_space,
device=self.model.device,
n_envs=1
)
for item in dataset:
# 每次处理一条 transition
obs = np.array(item["state"], dtype=np.float32)
next_obs = np.array(item["next_state"], dtype=np.float32)
action = np.array([item["action"]], dtype=np.int8) # shape: (1,)
reward = np.array(item["reward"], dtype=np.float32)
done = np.array(item["done"], dtype=bool)
# ✅ 一条一条地添加进 buffer
self.model.replay_buffer.add(
obs=obs,
next_obs=next_obs,
action=action,
reward=reward,
done=done,
infos=[{}]
)
print(f"✅ 成功向 replay buffer 添加 {len(dataset)} 条数据")
# 开始训练
self.model.learn(total_timesteps=total_timesteps)
def predict(self, observation):
# deterministic=False 表示允许探索
action, _states = self.model.predict(observation, deterministic=False)
return action.item()
def save(self):
self.model.save(MODEL_PATH)
print(f"💾 模型已保存至 {MODEL_PATH}.zip")
if __name__ == '__main__':
from environment.dialogue_env import MarketingDialogueEnv
env = MarketingDialogueEnv()
# 创建DQN智能体实例
agent = DQNAgent(env)
action = agent.predict(np.zeros(384))
print(action)
- 文件位置:Marketing_Script_Assistant/agents/offline_dqn.py
使用 DQN 算法进行 fine-tune,自定义一个OfflineDQN类,继承DQN类,重写learn方法,这样在训练的时候,就会使用UI对话界面交互的历史数据一次性的进行离线训练,而不必实时交互进行在线训练了。
from stable_baselines3 import DQN
class OfflineDQN(DQN):
"""自定义DQN类,完全绕过环境交互"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 禁用所有与环境交互相关的设置
self.learning_starts = 0
# self.train_freq = (1, "step")
self.train_freq = 1
self.gradient_steps = 1
self.replay_buffer = None # 稍后我们会设置它
def learn(self, total_timesteps: int, **kwargs) -> "OfflineDQN":
"""完全离线学习,不与环境交互"""
# 1. 准备训练
self._setup_learn(total_timesteps)
# 2. 自定义训练循环
for step in range(total_timesteps):
# 只从回放缓冲区学习,不与环境交互
self.train(gradient_steps=self.gradient_steps, batch_size=self.batch_size)
# 定期更新目标网络
if step % self.target_update_interval == 0:
self._on_step()
# 日志记录(可选)
if step % 100 == 0:
self.logger.record("train/step", step)
return self
4.3.3 开发UI对话界面¶
-
开发UI对话界面
- 文件位置:Marketing_Script_Assistant/ui/app.py
# -*- coding: utf-8 -*-
import sys
import os
import numpy as np
import streamlit as st
# 手动添加项目根目录到 sys.path,以便导入其他模块
PROJECT_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, PROJECT_ROOT)
# 导入自定义模块
from environment.dialogue_env import MarketingDialogueEnv # 自定义对话环境
from agents.dqn_agent import DQNAgent # 强化学习智能体(DQN算法)
from utils.state_utils import build_state # 构建状态向量工具函数
from utils.data_logger import log_transition # 记录训练数据日志工具
from llm.qwen_api import call_qwen # 调用大模型接口
st.set_page_config(page_title="💬 营销话术智能助手 - 基于 RLHF + Qwen", layout="wide")
# 初始化 session_state 变量
# 初始化用户在文本框中输入的问题
if "new_input" not in st.session_state:
st.session_state.new_input = ''
# 初始化强化学习环境
if "env" not in st.session_state:
st.session_state.env = MarketingDialogueEnv()
# 初始化智能体
if "agent" not in st.session_state:
st.session_state.agent = DQNAgent(st.session_state.env)
# 初始化对话历史
if "dialogue_history" not in st.session_state:
st.session_state.dialogue_history = []
# 初始化对话轮数
if "turn_count" not in st.session_state:
st.session_state.turn_count = 0
# 初始化是否结束对话标识,如果为True,则标识为对话结束
if "done" not in st.session_state:
st.session_state.done = False
# 初始化当前状态
if "current_state" not in st.session_state:
st.session_state.current_state = None
# 初始化营销策略,对应的是强化学习中的动作(action)
if "action" not in st.session_state:
st.session_state.action = None
# 获取用户反馈
if "feed_back" not in st.session_state:
st.session_state.feed_back = None
# 初始化用户反馈对应的奖励分数,对应的是强化学习中的奖励(reward)
if "reward" not in st.session_state:
st.session_state.reward = None
# 初始化下一步状态
if "next_state" not in st.session_state:
st.session_state.next_state = None
# 初始化是否提交表单标识,用来控制前端页面表单的渲染
if "form_submitted" not in st.session_state:
st.session_state.form_submitted = False
# 定义用户反馈字典,用户不同的反馈选项,对应的奖励分数不同
user_feedback_dict = {"用户成交或明确表示签约/购买": 5.0,
"用户表现出明显兴趣,如主动提问、索要报价": 3.0,
"用户要求发送产品资料、方案、合同等": 2.5,
"用户对产品或服务细节进行提问": 2.0,
"用户提出异议(如价格贵、没需求等)但仍在沟通": 0.5,
"用户有购买意向但表示价格有压力": 1.0,
"用户未表现明确态度,如说“再看看”": 0.0,
"用户礼貌拒绝,如“先不考虑,谢谢”": -1.0,
"用户明确表示没兴趣或说“不需要”": -2.0,
"用户长时间不回复、已读不回、敷衍应答": -2.5,
"用户强烈抗拒或直接中断对话(如挂电话、拉黑)": -3.0
}
# 侧边栏,设置一个开启对话按钮
with st.sidebar:
if st.button("新开启对话"):
st.session_state.dialogue_history = []
st.session_state.turn_count = 0
st.session_state.done = False
st.session_state.current_state = None
st.session_state.action = None
st.session_state.reward = None
st.session_state.next_state = None
st.title("💬 营销话术智能助手 - 基于 RLHF + Qwen")
# 展示历史对话记录
for msg in st.session_state.dialogue_history:
st.markdown(f"**👤 用户:** {msg['user']}")
st.markdown(f"**🧠 策略:** {msg['strategy']}")
st.markdown(f"**🤖 系统:** {msg['bot']}")
if msg.get("feedback"):
st.markdown(f"**👤 用户反馈:** {msg['feedback']}")
st.markdown("---")
# 接收用户输入
st.session_state.new_input = st.text_input("请输入您的问题:", key=f"user_input_{st.session_state.turn_count}")
if st.session_state.new_input:
# 当前的状态,由历史会话+最新输入生成的
current_state = build_state(st.session_state.dialogue_history + [{"user": st.session_state.new_input, "bot": ""}])
# 使用DQN模型预测,预测的动作就代表的是本案例中的营销策略(营销策略是离散的,共计十种)
action = st.session_state.agent.predict(current_state)
st.session_state.action = action
# 提取营销策略对应的策略描述以及参考样例
strategy = st.session_state.env.actions[action][0]
example = st.session_state.env.actions[action][1]
# 拼接提示词
prompt = f"你是一个专业的销售人员,{strategy},用户的问题是:{st.session_state.new_input},{example},注意事项:不要在回答中出现策略描述"
# 调用大模型生成回复
response = call_qwen(prompt)
# response = prompt # 测试用,直接返回拼接提示词
# 将当前状态保存到session_state
st.session_state.current_state = current_state
# 将回复轮次+1
st.session_state.turn_count += 1
# 添加记录
new_msg = {
"user": st.session_state.new_input,
"strategy": strategy,
"bot": response
}
st.session_state.dialogue_history.append(new_msg)
# 这里的会话显示,目的是方便用户看到大模型回复以后,再决定给出哪个反馈
st.markdown(f"**👤 用户:** {new_msg['user']}")
st.markdown(f"**🧠 策略:** {new_msg['strategy']}")
st.markdown(f"**🤖 系统:** {new_msg['bot']}")
st.markdown(f"**👤 用户反馈:** {''}")
st.markdown("---")
st.session_state.form_submitted = False
# 如果用户未提交表单,则显示表单
if not st.session_state.form_submitted:
st.markdown("### 📢 请给出人“工反馈结果:")
# 创建表单
with st.form(key="feedback_form", enter_to_submit=False):
# 单选框选项
options = list(user_feedback_dict.keys())
# 创建单选框
feedback_action = st.radio(
"请给出人工反馈结果:",
options=options,
index=None, # 不设置默认值
key=f"feedback_action_{len(st.session_state.dialogue_history)}" # 动态 Key 避免冲突
)
# 确认按钮(点击后提交表单)
st.session_state.form_submitted = st.form_submit_button("确认提交")
# 如果已提交表单
if st.session_state.form_submitted:
# 获取人工反馈结果
st.session_state.feedback = feedback_action
# 获取反馈结果对应的奖励分数
st.session_state.reward = user_feedback_dict.get(feedback_action)
# 历史会话添加反馈结果
st.session_state.dialogue_history[-1]["feedback"] = feedback_action
# 出现明确签约/购买,或者表现出明确拒绝,则判定为结束会话。done对应强化学习中的是否结束一个episode
st.session_state.done = feedback_action in ("用户成交或明确表示签约/购买",
"用户明确表示没兴趣或说“不需要”",
"用户长时间不回复、已读不回、敷衍应答",
"用户强烈抗拒或直接中断对话(如挂电话、拉黑)"
)
# 构建下一个状态,包含本轮对话的历史会话生成下一个状态
st.session_state.next_state = build_state(st.session_state.dialogue_history)
# 确保 current_state 是一个有效的 numpy.ndarray
if st.session_state.current_state is None:
st.session_state.current_state = np.array([]) # 或其他合适的初始状态
# 记录状态转移
log_transition(
st.session_state.current_state.tolist(),
st.session_state.action,
st.session_state.reward,
st.session_state.next_state.tolist(),
st.session_state.done,
)
st.rerun()
- 开发大模型生成话术
- 文件位置:Marketing_Script_Assistant/llm/qwen_api.py
为保证生成效果,这里选择了api 的方式调用。
本地大模型调用的方式也提供了,详见下方的注释。
# 在线api调用大模型
import os
import dashscope
from dashscope import Generation
# 从环境变量中获取 API_KEY
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
if dashscope.api_key is None:
raise ValueError("环境变量 DASHSCOPE_API_KEY 未设置")
def call_qwen(prompt: str) -> str:
response = Generation.call(
model="qwen-turbo",
prompt=prompt
)
return response.output.text
# 本地大模型调用
# from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# # 替换为 DeepSeek-R1-Distill-Qwen-1.5B 本地模型路径
# MODEL_PATH = "../models/DeepSeek-R1-Distill-Qwen-1_5b"
#
# # 加载 tokenizer 和模型
# tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True)
#
# # 创建文本生成 pipeline
# qwen_pipeline = pipeline(
# "text-generation",
# model=model,
# tokenizer=tokenizer,
# device=0 # 使用 GPU(若无 GPU,请删除或设为 -1)
# )
#
#
# def call_qwen(prompt: str) -> str:
# # 使用 DeepSeek-R1-Distill-Qwen-1.5B 模型生成回答
# response = qwen_pipeline(
# prompt,
# max_new_tokens=150, # 推荐方式
# truncation=True,
# num_return_sequences=1
# )
# return response[0]['generated_text']
#
#
4.3.4 收集交互数据¶
-
文本嵌入
- 文件位置:Marketing_Script_Assistant/utils/sentence_encoder.py
import os
from sentence_transformers import SentenceTransformer
import numpy as np
class SentenceEncoder:
def __init__(self, model_name='paraphrase-multilingual-MiniLM-L12-v2'):
project_root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
self.model_path = os.path.join(project_root, "models", model_name)
self.model = SentenceTransformer(self.model_path)
def encode(self, text: str) -> np.ndarray:
"""
对输入文本进行编码,返回句向量
"""
embedding = self.model.encode(text, convert_to_numpy=True)
return embedding
-
构建会话的状态标识(编码)
- 文件位置:Marketing_Script_Assistant/utils/state_utils.py
"""
对话状态处理工具模块
"""
from typing import List, Dict
import numpy as np
from utils.sentence_encoder import SentenceEncoder
encoder = SentenceEncoder()
def build_state(dialogue_history: List[Dict[str, str]]) -> np.ndarray:
"""
构建当前对话状态表示(使用 编码)
参数:
dialogue_history (List[Dict]): 对话历史列表,每个元素包含 "user" 和 "bot"
返回:
np.ndarray: 表示当前状态的向量(384维)
"""
if not dialogue_history:
return np.zeros(384) # 根据所选嵌入模型维度调整,paraphrase-multilingual-MiniLM-L12-v2维度是384
all_utterances = []
for turn in dialogue_history:
all_utterances.append(turn.get("user", ""))
all_utterances.append(turn.get("bot", ""))
embeddings = encoder.model.encode(all_utterances, convert_to_numpy=True)
state = np.mean(embeddings, axis=0) # 取平均向量作为状态
return state.astype(np.float32)
def normalize_state(state: np.ndarray) -> np.ndarray:
norm = np.linalg.norm(state)
return state / norm if norm != 0 else state
-
保存交互数据
- 数据保存到:Marketing_Script_Assistant/data/dialogue_logs.json
dialogue_logs.json文件存储的是一个 List[Dict],每个字典包含以下字段:
字段名 类型 说明
state np.ndarray 当前状态向量(如 384 维)
action int Agent 选择的动作编号
reward float 用户反馈带来的奖励值
next_state np.ndarray 执行动作后的下一个状态
done bool 是否结束本轮对话
- 文件位置:Marketing_Script_Assistant/utils/data_logger.py
import json
import os
from typing import Dict, Any, List
project_root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
DATA_PATH = os.path.join(project_root, "data", "dialogue_logs.json")
def ensure_data_file():
"""确保数据文件存在,若不存在则创建空文件"""
if not os.path.exists(os.path.dirname(DATA_PATH)):
os.makedirs(os.path.dirname(DATA_PATH))
if not os.path.exists(DATA_PATH):
with open(DATA_PATH, 'w', encoding='utf-8') as f:
json.dump([], f)
def load_dialogue_logs() -> List[Dict[str, Any]]:
"""加载现有交互数据"""
ensure_data_file()
try:
with open(DATA_PATH, 'r', encoding='utf-8') as f:
data = json.load(f)
return data
except (json.JSONDecodeError, FileNotFoundError, IOError):
print("⚠️ 交互文件损坏或为空,将创建新文件")
return []
def save_dialogue_logs(data: List[Dict[str, Any]]):
"""保存数据到 JSON 文件"""
with open(DATA_PATH, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def log_transition(state: List[float], action: int, reward: float, next_state: List[float], done: bool):
"""
记录一次交互过程(state -> action -> reward -> next_state)
参数:
state (List[float]): 当前状态(建议转换为 list)
action (int): 动作编号
reward (float): 奖励值
next_state (List[float]): 下一状态
done (bool): 是否结束对话
"""
transition = {
"state": state,
"action": action,
"reward": reward,
"next_state": next_state,
"done": done
}
data = load_dialogue_logs()
data.append(transition)
save_dialogue_logs(data)
print(f"✅ 已记录一条交互数据:action={action}, reward={reward}, done={done}")
4.3.5 模型训练¶
- 文件位置:Marketing_Script_Assistant/train.py
from environment.dialogue_env import MarketingDialogueEnv # 导入营销对话环境
from agents.dqn_agent import DQNAgent # 导入DQN智能体
import json
def load_training_data(data_path="./data/dialogue_logs.json"):
"""
加载训练数据
Args:
data_path (str): 训练数据文件路径,默认为"./data/dialogue_logs.json"
Returns:
list: 包含对话日志的列表,如果加载失败则返回空列表
"""
try:
with open(data_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f"✅ 加载训练数据共 {len(data)} 条")
return data
except Exception as e:
print(f"❌ 加载数据失败:{e}")
return []
def run_training():
"""
运行训练流程
初始化营销对话环境和DQN智能体,加载训练数据并开始训练
"""
# 创建营销对话环境实例
env = MarketingDialogueEnv()
# 创建DQN智能体实例
agent = DQNAgent(env)
# 加载训练数据
training_data = load_training_data()
# 检查是否有训练数据
if len(training_data) == 0:
print(" 无训练数据,请先运行UI界面收集对话数据")
return
print(" 开始训练……")
# 开始训练,训练步数为数据条数的10倍
agent.train(total_timesteps=len(training_data) * 10, dataset=training_data)
# 保存训练好的模型
agent.save()
# 程序入口点
if __name__ == "__main__":
run_training()
4.3.6 模型应用¶
- 文件位置:Marketing_Script_Assistant/agents/dqn_agent.py
在实例化DQNAgent的时候,会进行模型文件的检查,如果有模型文件,则直接加载已有模型;
如果没有,则实例化OfflineDQN,使用新模型初始化。
class DQNAgent:
def __init__(self, env):
self.env = env
if os.path.exists(f"{MODEL_PATH}.zip"):
self.model = OfflineDQN.load(MODEL_PATH)
print("✅ 加载已有模型")
else:
self.model = OfflineDQN("MlpPolicy",
self.env,
verbose=1,
buffer_size=100000,
exploration_fraction=0.3, # 增加探索时间比例
exploration_initial_eps=1.0, # 初始探索率
exploration_final_eps=0.01, # 最终探索率(更小的值)
learning_starts=1000, # 增加开始学习前的步数
target_update_interval=500 # 更新目标网络的频率
)
print("🆕 使用新模型初始化")
4.4 环境准备与依赖安装¶
pip install streamlit gym stable-baselines3 scikit-learn sentence-transformers dashscope transformers torch accelerate
4.5 使用流程详解¶
✅ 步骤一:部署或调用大模型¶
使用阿里云 DashScope API(推荐) 1、注册 阿里云 2、获取 API Key 3、替换 llm/qwen_api.py 中的密钥,或者在电脑的环境变量中设置秘钥
dashscope.api_key = "YOUR_API_KEY_HERE"
✅ 步骤二:启动 UI 对话界面¶
运行 Streamlit 应用:
streamlit run ui/app.py
✅ 步骤三:收集训练数据¶
在 UI 界面中进行多轮对话后,dialogue_logs.json文件将被自动生成,并记录以下信息:
| 字段名 | 类型 | 说明 |
|---|---|---|
| state | np.ndarray | 当前状态向量(如 384 维) |
| action | int | Agent 选择的动作编号 |
| reward | float | 用户反馈带来的奖励值 |
| next_state | np.ndarray | 执行动作后的下一个状态 |
| done | bool | 是否结束本轮对话 |
✅ 步骤四:训练 DQN 模型¶
运行训练脚本:
python train.py
✅ 步骤五:加载训练好的模型进行推理¶
重启 Streamlit 应用时,如果检测到 models/dqn_marketing_model.zip 存在,则会自动加载模型并使用最新策略进行决策。