A Look at Spring AI Alibaba's RedisChatMemory

This article takes a look at RedisChatMemory in Spring AI Alibaba.

RedisChatMemory

community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/RedisChatMemory.java

public class RedisChatMemory implements ChatMemory, AutoCloseable {

  private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);

  private static final String DEFAULT_KEY_PREFIX = "spring_ai_alibaba_chat_memory";

  private static final String DEFAULT_HOST = "127.0.0.1";

  private static final int DEFAULT_PORT = 6379;

  private static final String DEFAULT_PASSWORD = null;

  private final JedisPool jedisPool;

  private final Jedis jedis;

  private final ObjectMapper objectMapper;

  public RedisChatMemory() {

    this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);
  }

  public RedisChatMemory(String host, int port, String password) {

    JedisPoolConfig poolConfig = new JedisPoolConfig();

    this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);
    this.jedis = jedisPool.getResource();
    this.objectMapper = new ObjectMapper();
    SimpleModule module = new SimpleModule();
    module.addDeserializer(Message.class, new MessageDeserializer());
    this.objectMapper.registerModule(module);

    logger.info("Connected to Redis at {}:{}", host, port);
  }

  @Override
  public void add(String conversationId, List<Message> messages) {

    String key = DEFAULT_KEY_PREFIX + conversationId;

    for (Message message : messages) {
      try {
        String messageJson = objectMapper.writeValueAsString(message);
        jedis.rpush(key, messageJson);
      }
      catch (JsonProcessingException e) {
        throw new RuntimeException("Error serializing message", e);
      }
    }

    logger.info("Added messages to conversationId: {}", conversationId);
  }

  @Override
  public List<Message> get(String conversationId, int lastN) {

    String key = DEFAULT_KEY_PREFIX + conversationId;

    List<String> messageStrings = jedis.lrange(key, -lastN, -1);
    List<Message> messages = new ArrayList<>();

    for (String messageString : messageStrings) {
      try {
        Message message = objectMapper.readValue(messageString, Message.class);
        messages.add(message);
      }
      catch (JsonProcessingException e) {
        throw new RuntimeException("Error deserializing message", e);
      }
    }

    logger.info("Retrieved {} messages for conversationId: {}", messages.size(), conversationId);

    return messages;
  }

  @Override
  public void clear(String conversationId) {

    String key = DEFAULT_KEY_PREFIX + conversationId;

    jedis.del(key);
    logger.info("Cleared messages for conversationId: {}", conversationId);
  }

  @Override
  public void close() {

    if (jedis != null) {

      jedis.close();

      logger.info("Redis connection closed.");
    }
    if (jedisPool != null) {

      jedisPool.close();

      logger.info("Jedis pool closed.");
    }
  }

  public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
    try {
      String key = DEFAULT_KEY_PREFIX + conversationId;

      List<String> all = jedis.lrange(key, 0, -1);

      if (all.size() >= maxLimit) {
        all = all.stream().skip(Math.max(0, deleteSize)).toList();
      }
      this.clear(conversationId);
      for (String message : all) {
        jedis.rpush(key, message);
      }
    }
    catch (Exception e) {
      logger.error("Error clearing messages from Redis chat memory", e);
      throw new RuntimeException(e);
    }
  }

  public void updateMessageById(String conversationId, String messages) {
    String key = "spring_ai_alibaba_chat_memory:" + conversationId;
    try {
      this.jedis.del(key);
      this.jedis.rpush(key, new String[] { messages });
    }
    catch (Exception var6) {
      logger.error("Error updating messages from Redis chat memory", var6);
      throw new RuntimeException(var6);
    }
  }

}

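The RedisChatMemory constructor creates a JedisPool, takes a single Jedis connection from it, and registers a MessageDeserializer for org.springframework.ai.chat.messages.Message on its ObjectMapper. The add method serializes each message to JSON and appends it with rpush to the list stored under spring_ai_alibaba_chat_memory{conversationId}. The get method reads the most recent lastN entries with lrange(key, -lastN, -1) and deserializes them back into Message objects. The clear method deletes the key, and close closes the Jedis connection first and then the JedisPool. Note that updateMessageById builds its key with a colon (spring_ai_alibaba_chat_memory:{conversationId}), so it writes to a different key than the one add, get, and clear use.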
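To make the list semantics concrete, here is a minimal in-memory sketch of the three Redis commands the class relies on. It is not the library's code: ListMemorySketch and its Map-backed store are hypothetical stand-ins, plain strings replace the JSON-serialized messages, and no Redis server is involved. It only illustrates how rpush, lrange with negative indexes, and del combine, plus the read-trim-rewrite steps of clearOverLimit from the listing above.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical in-memory stand-in for the Redis list commands used in the listing.
// It is not part of Spring AI Alibaba or Jedis.
final class ListMemorySketch {

    private static final String KEY_PREFIX = "spring_ai_alibaba_chat_memory";

    private final Map<String, List<String>> store = new HashMap<>();

    // RPUSH: append to the tail, creating the list on first use.
    void rpush(String key, String value) {
        store.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
    }

    // LRANGE: both ends inclusive; negative indexes count from the tail,
    // and out-of-range indexes are clamped rather than rejected.
    List<String> lrange(String key, int start, int stop) {
        List<String> list = store.getOrDefault(key, List.of());
        int size = list.size();
        int from = start < 0 ? Math.max(size + start, 0) : start;
        int to = stop < 0 ? size + stop : Math.min(stop, size - 1);
        if (from > to || from >= size) {
            return new ArrayList<>();
        }
        return new ArrayList<>(list.subList(from, to + 1));
    }

    // DEL: drop the whole list.
    void del(String key) {
        store.remove(key);
    }

    void add(String conversationId, List<String> messages) {
        for (String message : messages) {
            rpush(KEY_PREFIX + conversationId, message);
        }
    }

    List<String> get(String conversationId, int lastN) {
        return lrange(KEY_PREFIX + conversationId, -lastN, -1);
    }

    void clear(String conversationId) {
        del(KEY_PREFIX + conversationId);
    }

    // Same steps as clearOverLimit in the listing: read everything, drop the
    // oldest deleteSize entries once the limit is reached, then rewrite the list.
    void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        String key = KEY_PREFIX + conversationId;
        List<String> all = lrange(key, 0, -1);
        if (all.size() >= maxLimit) {
            all = all.stream().skip(Math.max(0, deleteSize)).toList();
        }
        del(key);
        for (String message : all) {
            rpush(key, message);
        }
    }

    public static void main(String[] args) {
        ListMemorySketch memory = new ListMemorySketch();
        memory.add("c1", List.of("m1", "m2", "m3", "m4"));
        System.out.println(memory.get("c1", 2));  // [m3, m4]
        System.out.println(memory.get("c1", 10)); // [m1, m2, m3, m4]
        System.out.println(memory.get("c1", 0));  // [m1, m2, m3, m4]
    }
}
```

Two edge cases follow from the lrange semantics: asking for more messages than exist simply returns the whole list, and lastN = 0 also returns the whole list, because -0 is 0 and the range 0..-1 covers every element.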

MessageDeserializer

community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/serializer/MessageDeserializer.java

public class MessageDeserializer extends JsonDeserializer<Message> {

  private static final Logger logger = LoggerFactory.getLogger(MessageDeserializer.class);

  public Message deserialize(JsonParser p, DeserializationContext ctxt) {
    ObjectMapper mapper = (ObjectMapper) p.getCodec();
    JsonNode node = null;
    Message message = null;
    try {
      node = mapper.readTree(p);
      String messageType = node.get("messageType").asText();
      switch (messageType) {
        case "USER" -> message = new UserMessage(node.get("text").asText(),
            mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
            }), mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
            }));
        case "ASSISTANT" -> message = new AssistantMessage(node.get("text").asText(),
            mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
            }), (List<AssistantMessage.ToolCall>) mapper.convertValue(node.get("toolCalls"),
                new TypeReference<Collection<AssistantMessage.ToolCall>>() {
                }),
            (List<Media>) mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
            }));
        default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
      }
      ;
    }
    catch (IOException e) {
      logger.error("Error deserializing message", e);
    }
    return message;
  }

}

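MessageDeserializer extends JsonDeserializer. It reads the messageType field and creates a UserMessage for USER or an AssistantMessage for ASSISTANT. Any other type throws an IllegalArgumentException, and an IOException during parsing is logged, after which the method returns null.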
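The dispatch on messageType can be sketched without Jackson or Spring AI on the classpath. Everything below is a simplified, hypothetical stand-in: Msg, UserMsg, and AssistantMsg replace the real Message types, and a Map<String, String> replaces the parsed JsonNode. It shows only the branching shape, including what happens to a type the switch does not know.

```java
import java.util.Map;

// Hypothetical, simplified stand-ins; not the Spring AI message classes.
final class MessageTypeDispatchSketch {

    interface Msg {
        String text();
    }

    record UserMsg(String text) implements Msg {
    }

    record AssistantMsg(String text) implements Msg {
    }

    // Picks the concrete type from the "messageType" discriminator field,
    // mirroring the switch in the deserializer shown above.
    static Msg fromFields(Map<String, String> fields) {
        String messageType = fields.get("messageType");
        if (messageType == null) {
            throw new IllegalArgumentException("Missing messageType");
        }
        return switch (messageType) {
            case "USER" -> new UserMsg(fields.get("text"));
            case "ASSISTANT" -> new AssistantMsg(fields.get("text"));
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }

    public static void main(String[] args) {
        System.out.println(fromFields(Map.of("messageType", "USER", "text", "hello")));
        System.out.println(fromFields(Map.of("messageType", "ASSISTANT", "text", "hi there")));
    }
}
```

With only two branches, a stored entry whose messageType is anything else (SYSTEM, for example) cannot be read back and ends in the default branch.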

Summary

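spring-ai-alibaba-redis-memory provides a Redis implementation of ChatMemory. Through Jedis, it appends messages with rpush, reads the most recent N with lrange, and removes a conversation's messages with del.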

doc

  • java2ai