聊聊Spring AI Alibaba的RedisChatMemory
序
本文主要研究一下Spring AI Alibaba的RedisChatMemory
RedisChatMemory
community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/RedisChatMemory.java
代码语言:javascript代码运行次数:0运行复制public class RedisChatMemory implements ChatMemory, AutoCloseable {
private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);
private static final String DEFAULT_KEY_PREFIX = "spring_ai_alibaba_chat_memory";
private static final String DEFAULT_HOST = "127.0.0.1";
private static final int DEFAULT_PORT = 6379;
private static final String DEFAULT_PASSWORD = null;
private final JedisPool jedisPool;
private final Jedis jedis;
private final ObjectMapper objectMapper;
public RedisChatMemory() {
this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);
}
public RedisChatMemory(String host, int port, String password) {
JedisPoolConfig poolConfig = new JedisPoolConfig();
this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);
this.jedis = jedisPool.getResource();
this.objectMapper = new ObjectMapper();
SimpleModule module = new SimpleModule();
module.addDeserializer(Message.class, new MessageDeserializer());
this.objectMapper.registerModule(module);
logger.info("Connected to Redis at {}:{}", host, port);
}
@Override
public void add(String conversationId, List<Message> messages) {
String key = DEFAULT_KEY_PREFIX + conversationId;
for (Message message : messages) {
try {
String messageJson = objectMapper.writeValueAsString(message);
jedis.rpush(key, messageJson);
}
catch (JsonProcessingException e) {
throw new RuntimeException("Error serializing message", e);
}
}
logger.info("Added messages to conversationId: {}", conversationId);
}
@Override
public List<Message> get(String conversationId, int lastN) {
String key = DEFAULT_KEY_PREFIX + conversationId;
List<String> messageStrings = jedis.lrange(key, -lastN, -1);
List<Message> messages = new ArrayList<>();
for (String messageString : messageStrings) {
try {
Message message = objectMapper.readValue(messageString, Message.class);
messages.add(message);
}
catch (JsonProcessingException e) {
throw new RuntimeException("Error deserializing message", e);
}
}
logger.info("Retrieved {} messages for conversationId: {}", messages.size(), conversationId);
return messages;
}
@Override
public void clear(String conversationId) {
String key = DEFAULT_KEY_PREFIX + conversationId;
jedis.del(key);
logger.info("Cleared messages for conversationId: {}", conversationId);
}
@Override
public void close() {
if (jedis != null) {
jedis.close();
logger.info("Redis connection closed.");
}
if (jedisPool != null) {
jedisPool.close();
logger.info("Jedis pool closed.");
}
}
public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
try {
String key = DEFAULT_KEY_PREFIX + conversationId;
List<String> all = jedis.lrange(key, 0, -1);
if (all.size() >= maxLimit) {
all = all.stream().skip(Math.max(0, deleteSize)).toList();
}
this.clear(conversationId);
for (String message : all) {
jedis.rpush(key, message);
}
}
catch (Exception e) {
logger.error("Error clearing messages from Redis chat memory", e);
throw new RuntimeException(e);
}
}
public void updateMessageById(String conversationId, String messages) {
String key = "spring_ai_alibaba_chat_memory:" + conversationId;
try {
this.jedis.del(key);
this.jedis.rpush(key, new String[] { messages });
}
catch (Exception var6) {
logger.error("Error updating messages from Redis chat memory", var6);
throw new RuntimeException(var6);
}
}
}
RedisChatMemory的构造器初始化了JedisPool，并给ObjectMapper注册了针对`org.springframework.ai.chat.messages.Message`类型的MessageDeserializer。

它的各个方法如下：
- add方法遍历messages，将每条消息序列化为JSON后rpush到key为`spring_ai_alibaba_chat_memory` + conversationId的列表中。注意前缀与conversationId直接拼接，中间没有分隔符。
- get方法通过lrange取出最近lastN条记录，再反序列化为Message对象。
- clear方法直接删除该key。
- close方法则先关闭jedis，再关闭jedisPool。
- clearOverLimit方法在消息数达到maxLimit时删除最早的deleteSize条消息。
MessageDeserializer
community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/serializer/MessageDeserializer.java
代码语言:javascript代码运行次数:0运行复制public class MessageDeserializer extends JsonDeserializer<Message> {
private static final Logger logger = LoggerFactory.getLogger(MessageDeserializer.class);
public Message deserialize(JsonParser p, DeserializationContext ctxt) {
ObjectMapper mapper = (ObjectMapper) p.getCodec();
JsonNode node = null;
Message message = null;
try {
node = mapper.readTree(p);
String messageType = node.get("messageType").asText();
switch (messageType) {
case "USER" -> message = new UserMessage(node.get("text").asText(),
mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
}), mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
}));
case "ASSISTANT" -> message = new AssistantMessage(node.get("text").asText(),
mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
}), (List<AssistantMessage.ToolCall>) mapper.convertValue(node.get("toolCalls"),
new TypeReference<Collection<AssistantMessage.ToolCall>>() {
}),
(List<Media>) mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
}));
default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
}
;
}
catch (IOException e) {
logger.error("Error deserializing message", e);
}
return message;
}
}
MessageDeserializer继承了JsonDeserializer，它读取messageType字段，对于USER类型创建UserMessage，对于ASSISTANT类型创建AssistantMessage，其他类型则抛出IllegalArgumentException。
小结
spring-ai-alibaba-redis-memory提供了ChatMemory的Redis实现：它通过Jedis使用rpush添加消息，通过lrange取出最近N条，通过del删除指定会话的消息。
doc
- java2ai