A Look at Tool Calling in Spring AI

This article takes a look at Tool Calling in Spring AI.

ToolCallback

org/springframework/ai/tool/ToolCallback.java

public interface ToolCallback extends FunctionCallback {

	/**
	 * Definition used by the AI model to determine when and how to call the tool.
	 */
	ToolDefinition getToolDefinition();

	/**
	 * Metadata providing additional information on how to handle the tool.
	 */
	default ToolMetadata getToolMetadata() {
		return ToolMetadata.builder().build();
	}

	/**
	 * Execute tool with the given input and return the result to send back to the AI
	 * model.
	 */
	String call(String toolInput);

	/**
	 * Execute tool with the given input and context, and return the result to send back
	 * to the AI model.
	 */
	default String call(String toolInput, @Nullable ToolContext tooContext) {
		if (tooContext != null && !tooContext.getContext().isEmpty()) {
			throw new UnsupportedOperationException("Tool context is not supported!");
		}
		return call(toolInput);
	}

	@Override
	@Deprecated // Call getToolDefinition().name() instead
	default String getName() {
		return getToolDefinition().name();
	}

	@Override
	@Deprecated // Call getToolDefinition().description() instead
	default String getDescription() {
		return getToolDefinition().description();
	}

	@Override
	@Deprecated // Call getToolDefinition().inputTypeSchema() instead
	default String getInputTypeSchema() {
		return getToolDefinition().inputSchema();
	}

}

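ToolCallback extends the FunctionCallback interface, although FunctionCallback is about to be deprecated. ToolCallback mainly defines the getToolDefinition, getToolMetadata and call methods. Its default call(toolInput, toolContext) throws UnsupportedOperationException when it receives a non-empty ToolContext. The inherited getName, getDescription and getInputTypeSchema methods are deprecated in favor of the ToolDefinition accessors. ToolCallback has two basic implementations: MethodToolCallback and FunctionToolCallback.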
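The effect of the default two-argument call can be seen in a minimal plain-Java model. This is a sketch, not Spring AI code. SimpleToolCallback, SimpleToolContext and EchoTool are hypothetical names, and only the default-method logic is copied from the interface above.

```java
import java.util.Map;

// Hypothetical stand-in for Spring AI's ToolContext: it just wraps a map.
record SimpleToolContext(Map<String, Object> context) {}

// Hypothetical stand-in for ToolCallback, reproducing only the two call(...) methods.
interface SimpleToolCallback {

    String call(String toolInput);

    // Same rule as the ToolCallback default: a callback that does not override
    // this method rejects a non-empty context instead of silently ignoring it.
    default String call(String toolInput, SimpleToolContext toolContext) {
        if (toolContext != null && !toolContext.context().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }
}

// A tool that only implements the one-argument call.
class EchoTool implements SimpleToolCallback {
    @Override
    public String call(String toolInput) {
        return "echo:" + toolInput;
    }
}
```

A callback that does need context, such as MethodToolCallback, overrides the two-argument method instead of relying on this default.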

MethodToolCallback

org/springframework/ai/tool/method/MethodToolCallback.java

public class MethodToolCallback implements ToolCallback {

	private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);

	private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

	private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

	private final ToolDefinition toolDefinition;

	private final ToolMetadata toolMetadata;

	private final Method toolMethod;

	@Nullable
	private final Object toolObject;

	private final ToolCallResultConverter toolCallResultConverter;

	public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
			@Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
		Assert.notNull(toolDefinition, "toolDefinition cannot be null");
		Assert.notNull(toolMethod, "toolMethod cannot be null");
		Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
				"toolObject cannot be null for non-static methods");
		this.toolDefinition = toolDefinition;
		this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
		this.toolMethod = toolMethod;
		this.toolObject = toolObject;
		this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
				: DEFAULT_RESULT_CONVERTER;
	}

	@Override
	public ToolDefinition getToolDefinition() {
		return toolDefinition;
	}

	@Override
	public ToolMetadata getToolMetadata() {
		return toolMetadata;
	}

	@Override
	public String call(String toolInput) {
		return call(toolInput, null);
	}

	@Override
	public String call(String toolInput, @Nullable ToolContext toolContext) {
		Assert.hasText(toolInput, "toolInput cannot be null or empty");

		logger.debug("Starting execution of tool: {}", toolDefinition.name());

		validateToolContextSupport(toolContext);

		Map<String, Object> toolArguments = extractToolArguments(toolInput);

		Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);

		Object result = callMethod(methodArguments);

		logger.debug("Successful execution of tool: {}", toolDefinition.name());

		Type returnType = toolMethod.getGenericReturnType();

		return toolCallResultConverter.convert(result, returnType);
	}

	@Nullable
	private Object callMethod(Object[] methodArguments) {
		if (isObjectNotPublic() || isMethodNotPublic()) {
			toolMethod.setAccessible(true);
		}

		Object result;
		try {
			result = toolMethod.invoke(toolObject, methodArguments);
		}
		catch (IllegalAccessException ex) {
			throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
		}
		catch (InvocationTargetException ex) {
			throw new ToolExecutionException(toolDefinition, ex.getCause());
		}
		return result;
	}

	//......
}	

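MethodToolCallback implements the ToolCallback interface. Its call method performs these steps:

1. It parses the JSON tool input into an argument map with extractToolArguments.
2. It builds the method arguments from that map and the ToolContext with buildMethodArguments.
3. It invokes the method with callMethod.
4. It converts the return value with toolCallResultConverter.convert, passing the method's generic return type.

callMethod invokes the method through reflection. It first calls setAccessible(true) if the tool object or the method is not public, and it wraps any exception thrown by the method in a ToolExecutionException. The following types are currently not supported as method parameters or return types: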

  • Optional
  • Asynchronous types (e.g. CompletableFuture, Future)
  • Reactive types (e.g. Flow, Mono, Flux)
  • Functional types (e.g. Function, Supplier, Consumer).
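The reflective core of callMethod can be sketched in plain Java as follows. The sketch assumes positional arguments and skips JSON parsing, ToolContext injection and result conversion. ReflectiveToolInvoker, ReflectiveToolException and CalculatorTools are hypothetical names used only for illustration.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

// Hypothetical exception standing in for Spring AI's ToolExecutionException.
class ReflectiveToolException extends RuntimeException {
    ReflectiveToolException(String toolName, Throwable cause) {
        super("Tool '" + toolName + "' failed: " + cause, cause);
    }
}

// Minimal reflective invoker modeled on MethodToolCallback.callMethod.
// Arguments are passed positionally; JSON parsing and type conversion are omitted.
class ReflectiveToolInvoker {

    // Looks up a declared method and fails fast if it is missing.
    static Method findMethod(Class<?> type, String name, Class<?>... parameterTypes) {
        try {
            return type.getDeclaredMethod(name, parameterTypes);
        }
        catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No method " + name + " on " + type.getName(), ex);
        }
    }

    static Object invoke(String toolName, Method method, Object target, Object... args) {
        boolean isStatic = Modifier.isStatic(method.getModifiers());
        if (!isStatic && target == null) {
            throw new IllegalArgumentException("target cannot be null for non-static methods");
        }
        // As in the original, open up access when the class or the method is not public.
        if (!Modifier.isPublic(method.getDeclaringClass().getModifiers())
                || !Modifier.isPublic(method.getModifiers())) {
            method.setAccessible(true);
        }
        try {
            return method.invoke(isStatic ? null : target, args);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            // Unwrap so callers see the exception thrown by the tool itself.
            throw new ReflectiveToolException(toolName, ex.getCause());
        }
    }
}

// Example tool class; it is package-private, so the setAccessible branch runs.
class CalculatorTools {
    int add(int a, int b) { return a + b; }
    int divide(int a, int b) { return a / b; }
}
```

Unwrapping InvocationTargetException keeps the tool's own exception as the cause, which matches what the original does when it builds a ToolExecutionException.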

FunctionToolCallback

org/springframework/ai/tool/function/FunctionToolCallback.java

public class FunctionToolCallback<I, O> implements ToolCallback {

	private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);

	private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

	private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

	private final ToolDefinition toolDefinition;

	private final ToolMetadata toolMetadata;

	private final Type toolInputType;

	private final BiFunction<I, ToolContext, O> toolFunction;

	private final ToolCallResultConverter toolCallResultConverter;

	public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
			BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
		Assert.notNull(toolDefinition, "toolDefinition cannot be null");
		Assert.notNull(toolInputType, "toolInputType cannot be null");
		Assert.notNull(toolFunction, "toolFunction cannot be null");
		this.toolDefinition = toolDefinition;
		this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
		this.toolFunction = toolFunction;
		this.toolInputType = toolInputType;
		this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
				: DEFAULT_RESULT_CONVERTER;
	}

	@Override
	public ToolDefinition getToolDefinition() {
		return toolDefinition;
	}

	@Override
	public ToolMetadata getToolMetadata() {
		return toolMetadata;
	}

	@Override
	public String call(String toolInput) {
		return call(toolInput, null);
	}

	@Override
	public String call(String toolInput, @Nullable ToolContext toolContext) {
		Assert.hasText(toolInput, "toolInput cannot be null or empty");

		logger.debug("Starting execution of tool: {}", toolDefinition.name());

		I request = JsonParser.fromJson(toolInput, toolInputType);
		O response = toolFunction.apply(request, toolContext);

		logger.debug("Successful execution of tool: {}", toolDefinition.name());

		return toolCallResultConverter.convert(response, null);
	}

	@Override
	public String toString() {
		return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
	}

	//......
}	

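FunctionToolCallback implements the ToolCallback interface. Its call method deserializes the request with JsonParser.fromJson(toolInput, toolInputType). It then obtains the response with toolFunction.apply(request, toolContext) and converts it with toolCallResultConverter.convert(response, null). Because the return type is passed as null, the default converter always serializes the response to JSON. The following types are currently not supported as input or return types: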

  • Primitive types
  • Optional
  • Collection types (e.g. List, Map, Array, Set)
  • Asynchronous types (e.g. CompletableFuture, Future)
  • Reactive types (e.g. Flow, Mono, Flux).
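The flow of FunctionToolCallback.call can be modeled without Spring AI or Jackson: decode the input, apply the function with its context, then encode the result. In this sketch the hypothetical FunctionToolSketch class takes plain functions for decoding and encoding, where the real class uses JsonParser and a ToolCallResultConverter. A Map stands in for ToolContext.

```java
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical model of FunctionToolCallback, not Spring AI's API: inputDecoder
// stands in for JsonParser.fromJson and resultEncoder for the result converter.
class FunctionToolSketch<I, O> {

    private final String name;
    private final Function<String, I> inputDecoder;
    private final BiFunction<I, Map<String, Object>, O> toolFunction;
    private final Function<O, String> resultEncoder;

    FunctionToolSketch(String name, Function<String, I> inputDecoder,
            BiFunction<I, Map<String, Object>, O> toolFunction, Function<O, String> resultEncoder) {
        this.name = name;
        this.inputDecoder = inputDecoder;
        this.toolFunction = toolFunction;
        this.resultEncoder = resultEncoder;
    }

    // Adapts a plain Function, which ignores the context, to the BiFunction shape.
    static <T, R> FunctionToolSketch<T, R> of(String name, Function<String, T> inputDecoder,
            Function<T, R> function, Function<R, String> resultEncoder) {
        return new FunctionToolSketch<>(name, inputDecoder, (input, context) -> function.apply(input), resultEncoder);
    }

    String call(String toolInput, Map<String, Object> context) {
        if (toolInput == null || toolInput.isBlank()) {
            throw new IllegalArgumentException("toolInput cannot be null or empty for tool " + name);
        }
        I request = inputDecoder.apply(toolInput);          // decode the model's raw arguments
        O response = toolFunction.apply(request, context);  // run the tool with its context
        return resultEncoder.apply(response);               // encode the result for the model
    }
}
```

For example, `FunctionToolSketch.of("square", Integer::parseInt, x -> x * x, r -> Integer.toString(r))` turns the input "7" into "49".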

Examples

class DateTimeTools {

    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}

MethodToolCallback

Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolCallback toolCallback = MethodToolCallback.builder()
    .toolDefinition(ToolDefinition.builder(method)
            .description("Get the current date and time in the user's timezone")
            .build())
    .toolMethod(method)
    .toolObject(new DateTimeTools())
    .build();

Alternatively, use the @Tool annotation:

class DateTimeTools {

    @Tool(description = "Get the current date and time in the user's timezone")
    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}

Or use the ToolCallbacks.from method:

ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools());

FunctionToolCallback

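import java.util.function.Function;

public enum Unit { C, F }

public record WeatherRequest(String location, Unit unit) {}

public record WeatherResponse(double temp, Unit unit) {}

public class WeatherService implements Function<WeatherRequest, WeatherResponse> {
    @Override
    public WeatherResponse apply(WeatherRequest request) {
        // Stub: always reports 30 degrees Celsius, whatever the location
        return new WeatherResponse(30.0, Unit.C);
    }
}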

ToolCallback toolCallback = FunctionToolCallback
    .builder("currentWeather", new WeatherService())
    .description("Get the weather in location")
    .inputType(WeatherRequest.class)
    .build();

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools(toolCallback)
    .call()
    .content();    

Alternatively, set it on the chat options:

ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(toolCallback)
    .build();
Prompt prompt = new Prompt("What's the weather like in Copenhagen?", chatOptions);
chatModel.call(prompt);

Or register it as a Spring bean and refer to it by name:

@Configuration(proxyBeanMethods = false)
class WeatherTools {

    WeatherService weatherService = new WeatherService();

	@Bean
	@Description("Get the weather in location")
	Function<WeatherRequest, WeatherResponse> currentWeather() {
		return weatherService;
	}

}

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools("currentWeather")
    .call()
    .content();

Tool Specification

ToolDefinition

org/springframework/ai/tool/definition/ToolDefinition.java

public interface ToolDefinition {

	/**
	 * The tool name. Unique within the tool set provided to a model.
	 */
	String name();

	/**
	 * The tool description, used by the AI model to determine what the tool does.
	 */
	String description();

	/**
	 * The schema of the parameters used to call the tool.
	 */
	String inputSchema();

	/**
	 * Create a default {@link ToolDefinition} builder.
	 */
	static DefaultToolDefinition.Builder builder() {
		return DefaultToolDefinition.builder();
	}

	/**
	 * Create a default {@link ToolDefinition} builder from a {@link Method}.
	 */
	static DefaultToolDefinition.Builder builder(Method method) {
		Assert.notNull(method, "method cannot be null");
		return DefaultToolDefinition.builder()
			.name(ToolUtils.getToolName(method))
			.description(ToolUtils.getToolDescription(method))
			.inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
	}

	/**
	 * Create a default {@link ToolDefinition} instance from a {@link Method}.
	 */
	static ToolDefinition from(Method method) {
		return ToolDefinition.builder(method).build();
	}

}

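ToolDefinition defines the name, description and inputSchema properties. Its builder(Method) method creates a DefaultToolDefinition builder from a Method. The name comes from ToolUtils.getToolName, the description from ToolUtils.getToolDescription and the input schema from JsonSchemaGenerator.generateForMethodInput. from(Method) builds the definition directly.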
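As a rough illustration of deriving a definition from a Method, the sketch below reads a hypothetical @SketchTool annotation. It falls back to the method name when no name is given, which is an assumption modeled on @Tool. When the description is missing it also reuses the method name, although Spring AI's ToolUtils may behave differently. The input schema is simply passed in as a string.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical annotation playing the role of @Tool in this sketch.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface SketchTool {
    String name() default "";
    String description() default "";
}

// Plain record mirroring the three ToolDefinition properties.
record SketchToolDefinition(String name, String description, String inputSchema) {

    // Reads name and description from the annotation, falling back to the
    // method name when either is absent (an assumption, see the lead-in).
    static SketchToolDefinition from(Method method, String inputSchema) {
        SketchTool tool = method.getAnnotation(SketchTool.class);
        String name = (tool != null && !tool.name().isEmpty()) ? tool.name() : method.getName();
        String description = (tool != null && !tool.description().isEmpty())
                ? tool.description() : method.getName();
        return new SketchToolDefinition(name, description, inputSchema);
    }
}

class SketchDateTimeTools {
    @SketchTool(description = "Get the current date and time in the user's timezone")
    String getCurrentDateTime() { return java.time.ZonedDateTime.now().toString(); }

    @SketchTool(name = "currentTime")
    String now() { return java.time.LocalTime.now().toString(); }
}
```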

Example

Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolDefinition toolDefinition = ToolDefinition.builder(method)
    .name("currentDateTime")
    .description("Get the current date and time in the user's timezone")
    .inputSchema(JsonSchemaGenerator.generateForMethodInput(method))
    .build();

JSON Schema

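Spring AI provides JsonSchemaGenerator to generate the JSON schema for the input parameters of a given method or function. Parameter descriptions can be supplied with the following annotations: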

  • @ToolParam(description = "…") from Spring AI
  • @JsonClassDescription("…") from Jackson
  • @JsonPropertyDescription("…") from Jackson
  • @Schema(description = "…") from Swagger

Example

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.context.i18n.LocaleContextHolder;

class DateTimeTools {

    @Tool(description = "Set a user alarm for the given time")
    void setAlarm(@ToolParam(description = "Time in ISO-8601 format") String time) {
        LocalDateTime alarmTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_DATE_TIME);
        System.out.println("Alarm set for " + alarmTime);
    }

}

Whether a parameter is required can be controlled with the following annotations:

  • @ToolParam(required = false) from Spring AI
  • @JsonProperty(required = false) from Jackson
  • @Schema(required = false) from Swagger
  • @Nullable from Spring Framework

Example:

class CustomerTools {

    @Tool(description = "Update customer information")
    void updateCustomerInfo(Long id, String name, @ToolParam(required = false) String email) {
        System.out.println("Updated info for customer with id: " + id);
    }

}
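Here is a deliberately naive generator that shows how parameter annotations can drive the description and required fields of a schema. It is not JsonSchemaGenerator. @SketchParam is a hypothetical stand-in for @ToolParam, only a handful of Java types are mapped, and descriptions are not JSON-escaped. Parameter names come out as arg0, arg1, ... unless the code is compiled with javac -parameters.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for @ToolParam, carrying only description and required.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface SketchParam {
    String description() default "";
    boolean required() default true;
}

class NaiveSchemaGenerator {

    // Maps a few Java types to JSON Schema types; everything else becomes "object".
    static String jsonType(Class<?> type) {
        if (type == String.class) return "string";
        if (type == boolean.class || type == Boolean.class) return "boolean";
        if (type == int.class || type == long.class || type == Integer.class || type == Long.class) return "integer";
        if (type == double.class || type == float.class || Number.class.isAssignableFrom(type)) return "number";
        return "object";
    }

    // Builds an object schema from the method parameters. A parameter is required
    // unless it is annotated with @SketchParam(required = false).
    static String generateForMethodInput(Method method) {
        StringBuilder properties = new StringBuilder();
        List<String> required = new ArrayList<>();
        for (Parameter p : method.getParameters()) {
            SketchParam ann = p.getAnnotation(SketchParam.class);
            if (properties.length() > 0) properties.append(',');
            properties.append('"').append(p.getName()).append("\":{\"type\":\"")
                    .append(jsonType(p.getType())).append('"');
            if (ann != null && !ann.description().isEmpty()) {
                properties.append(",\"description\":\"").append(ann.description()).append('"');
            }
            properties.append('}');
            if (ann == null || ann.required()) required.add("\"" + p.getName() + "\"");
        }
        return "{\"type\":\"object\",\"properties\":{" + properties
                + "},\"required\":[" + String.join(",", required) + "]}";
    }
}

class SketchCustomerTools {
    void updateCustomerInfo(Long id, String name,
            @SketchParam(required = false, description = "Contact email") String email) { }
}
```

For updateCustomerInfo, only id and name end up in the required array, which mirrors the effect of @ToolParam(required = false) in the example above.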

Result Conversion

Spring AI provides ToolCallResultConverter, which converts the data returned by a tool call before it is sent back to the AI model.

org/springframework/ai/tool/execution/ToolCallResultConverter.java

@FunctionalInterface
public interface ToolCallResultConverter {

	/**
	 * Given an Object returned by a tool, convert it to a String compatible with the
	 * given class type.
	 */
	String convert(@Nullable Object result, @Nullable Type returnType);

}

It has a default implementation, DefaultToolCallResultConverter:

public final class DefaultToolCallResultConverter implements ToolCallResultConverter {

	private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);

	@Override
	public String convert(@Nullable Object result, @Nullable Type returnType) {
		if (returnType == Void.TYPE) {
			logger.debug("The tool has no return type. Converting to conventional response.");
			return "Done";
		}
		else {
			logger.debug("Converting tool result to JSON.");
			return JsonParser.toJson(result);
		}
	}

}

DefaultToolCallResultConverter returns "Done" when the return type is void (Void.TYPE). Otherwise it serializes the result into a JSON string with JsonParser.toJson(result).
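The default converter's rule can be mimicked in a few lines. This sketch replaces JsonParser.toJson with a tiny hand-written serializer that only understands null, numbers, booleans, maps and strings; anything else is emitted as its toString() in quotes. SketchResultConverter and SketchDefaultConverter are hypothetical names.

```java
import java.lang.reflect.Type;
import java.util.Map;
import java.util.StringJoiner;

// Hypothetical stand-in for ToolCallResultConverter, with the same method shape.
@FunctionalInterface
interface SketchResultConverter {
    String convert(Object result, Type returnType);
}

// Mimics DefaultToolCallResultConverter: "Done" for void, JSON otherwise.
class SketchDefaultConverter implements SketchResultConverter {

    @Override
    public String convert(Object result, Type returnType) {
        if (returnType == Void.TYPE) {
            return "Done";
        }
        return toJson(result);
    }

    // Minimal serializer standing in for JsonParser.toJson.
    static String toJson(Object value) {
        if (value == null) return "null";
        if (value instanceof Number || value instanceof Boolean) return value.toString();
        if (value instanceof Map<?, ?> map) {
            StringJoiner joiner = new StringJoiner(",", "{", "}");
            map.forEach((k, v) -> joiner.add(toJson(String.valueOf(k)) + ":" + toJson(v)));
            return joiner.toString();
        }
        // Strings and anything else become JSON strings with quotes and backslashes escaped.
        String s = value.toString().replace("\\", "\\\\").replace("\"", "\\\"");
        return "\"" + s + "\"";
    }
}
```

FunctionToolCallback passes null as the return type. As a result, only method-based tools declared void get the "Done" response, and every other result is serialized.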

You can also specify your own converter, for example:

class CustomerTools {

    @Tool(description = "Retrieve customer information", resultConverter = CustomToolCallResultConverter.class)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}

Tool Context

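Spring AI provides ToolContext for passing additional contextual information to a tool. Developers can use it to supply extra, user-provided data, which the tool can then use during execution alongside the arguments passed by the AI model. For example: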

class CustomerTools {

    @Tool(description = "Retrieve customer information")
    Customer getCustomerInfo(Long id, ToolContext toolContext) {
        return customerRepository.findById(id, toolContext.getContext().get("tenantId"));
    }

}

With ChatClient:

ChatModel chatModel = ...

String response = ChatClient.create(chatModel)
        .prompt("Tell me more about the customer with ID 42")
        .tools(new CustomerTools())
        .toolContext(Map.of("tenantId", "acme"))
        .call()
        .content();

System.out.println(response);

With ChatModel:

ChatModel chatModel = ...
ToolCallback[] customerTools = ToolCallbacks.from(new CustomerTools());
ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(customerTools)
    .toolContext(Map.of("tenantId", "acme"))
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
chatModel.call(prompt);
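The same pattern can be sketched in plain Java, with a Map standing in for ToolContext and an in-memory repository standing in for customerRepository. TenantCustomerRepository and ContextAwareCustomerTools are hypothetical. The point is that the model supplies only the id, while the tenantId comes from the context supplied by the application.

```java
import java.util.Map;

// Hypothetical in-memory repository keyed by tenant, then by customer id.
class TenantCustomerRepository {
    private final Map<String, Map<Long, String>> data = Map.of(
            "acme", Map.of(42L, "Alice (acme)"),
            "globex", Map.of(42L, "Bob (globex)"));

    String findById(Long id, String tenantId) {
        return data.getOrDefault(tenantId, Map.of()).getOrDefault(id, "not found");
    }
}

// Context-aware tool: id would come from the model's arguments, while tenantId
// is read from the application-provided context map.
class ContextAwareCustomerTools {
    private final TenantCustomerRepository repository = new TenantCustomerRepository();

    String getCustomerInfo(Long id, Map<String, Object> toolContext) {
        Object tenantId = toolContext.get("tenantId");
        if (tenantId == null) {
            throw new IllegalStateException("tenantId missing from tool context");
        }
        return repository.findById(id, tenantId.toString());
    }
}
```

The same model-supplied id 42 resolves to different customers depending on the tenantId in the context.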

Return Direct

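Spring AI provides a returnDirect option. When it is set to true, the result of the tool call is returned directly to the caller instead of being sent back to the model. By default, the result is sent back to the AI model, which processes it before the response is returned to the user. For example: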

class CustomerTools {

    @Tool(description = "Retrieve customer information", returnDirect = true)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}

Alternatively, via ToolMetadata:

ToolMetadata toolMetadata = ToolMetadata.builder()
    .returnDirect(true)
    .build();
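A single model response can request several tools. In that case, the executeToolCall method of DefaultToolCallingManager (its source appears later in this article) combines their returnDirect flags. The first ToolCallback initializes the flag and each further one is ANDed into it, so the result goes straight to the caller only if every tool asks for it. Here is a minimal sketch of that rule for ToolCallback instances only; ReturnDirectAggregator is a hypothetical name.

```java
import java.util.List;

class ReturnDirectAggregator {

    // Mirrors the ToolCallback branches of the loop: the first tool's flag
    // initializes the result and every later flag is ANDed into it.
    static boolean combine(List<Boolean> perToolFlags) {
        Boolean returnDirect = null;
        for (boolean flag : perToolFlags) {
            returnDirect = (returnDirect == null) ? flag : (returnDirect && flag);
        }
        // The manager only runs when at least one tool call exists; default to false here.
        return returnDirect != null && returnDirect;
    }
}
```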

ToolCallingManager

org/springframework/ai/model/tool/ToolCallingManager.java

public interface ToolCallingManager {

	/**
	 * Resolve the tool definitions from the model's tool calling options.
	 */
	List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);

	/**
	 * Execute the tool calls requested by the model.
	 */
	ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);

	/**
	 * Create a default {@link ToolCallingManager} builder.
	 */
	static DefaultToolCallingManager.Builder builder() {
		return DefaultToolCallingManager.builder();
	}

}

ToolCallingManager defines the resolveToolDefinitions and executeToolCalls methods; its default implementation is DefaultToolCallingManager.

DefaultToolCallingManager

org/springframework/ai/model/tool/DefaultToolCallingManager.java

public class DefaultToolCallingManager implements ToolCallingManager {

	private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);

	// @formatter:off

	private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY
			= ObservationRegistry.NOOP;

	private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
			= new DelegatingToolCallbackResolver(List.of());

	private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
			= DefaultToolExecutionExceptionProcessor.builder().build();

	// @formatter:on

	private final ObservationRegistry observationRegistry;

	private final ToolCallbackResolver toolCallbackResolver;

	private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

	public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
			ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
		Assert.notNull(observationRegistry, "observationRegistry cannot be null");
		Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
		Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

		this.observationRegistry = observationRegistry;
		this.toolCallbackResolver = toolCallbackResolver;
		this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
	}

	@Override
	public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
		Assert.notNull(chatOptions, "chatOptions cannot be null");

		List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
		for (String toolName : chatOptions.getToolNames()) {
			// Skip the tool if it is already present in the request toolCallbacks.
			// That might happen if a tool is defined in the options
			// both as a ToolCallback and as a tool name.
			if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
				continue;
			}
			FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
			if (toolCallback == null) {
				throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
			}
			toolCallbacks.add(toolCallback);
		}

		return toolCallbacks.stream().map(functionCallback -> {
			if (functionCallback instanceof ToolCallback toolCallback) {
				return toolCallback.getToolDefinition();
			}
			else {
				return ToolDefinition.builder()
					.name(functionCallback.getName())
					.description(functionCallback.getDescription())
					.inputSchema(functionCallback.getInputTypeSchema())
					.build();
			}
		}).toList();
	}

	@Override
	public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
		Assert.notNull(prompt, "prompt cannot be null");
		Assert.notNull(chatResponse, "chatResponse cannot be null");

		Optional<Generation> toolCallGeneration = chatResponse.getResults()
			.stream()
			.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
			.findFirst();

		if (toolCallGeneration.isEmpty()) {
			throw new IllegalStateException("No tool call requested by the chat model");
		}

		AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();

		ToolContext toolContext = buildToolContext(prompt, assistantMessage);

		InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
				toolContext);

		List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
				assistantMessage, internalToolExecutionResult.toolResponseMessage());

		return ToolExecutionResult.builder()
			.conversationHistory(conversationHistory)
			.returnDirect(internalToolExecutionResult.returnDirect())
			.build();
	}

	//......

	/**
	 * Execute the tool call and return the response message. To ensure backward
	 * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are
	 * supported.
	 */
	private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
			ToolContext toolContext) {
		List<FunctionCallback> toolCallbacks = List.of();
		if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
			toolCallbacks = toolCallingChatOptions.getToolCallbacks();
		}
		else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) {
			toolCallbacks = functionOptions.getFunctionCallbacks();
		}

		List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

		Boolean returnDirect = null;

		for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

			logger.debug("Executing tool call: {}", toolCall.name());

			String toolName = toolCall.name();
			String toolInputArguments = toolCall.arguments();

			FunctionCallback toolCallback = toolCallbacks.stream()
				.filter(tool -> toolName.equals(tool.getName()))
				.findFirst()
				.orElseGet(() -> toolCallbackResolver.resolve(toolName));

			if (toolCallback == null) {
				throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
			}

			if (returnDirect == null && toolCallback instanceof ToolCallback callback) {
				returnDirect = callback.getToolMetadata().returnDirect();
			}
			else if (toolCallback instanceof ToolCallback callback) {
				returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
			}
			else if (returnDirect == null) {
				// This is a temporary solution to ensure backward compatibility with
				// FunctionCallback.
				// TODO: remove this block when FunctionCallback is removed.
				returnDirect = false;
			}

			String toolResult;
			try {
				toolResult = toolCallback.call(toolInputArguments, toolContext);
			}
			catch (ToolExecutionException ex) {
				toolResult = toolExecutionExceptionProcessor.process(ex);
			}

			toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
		}

		return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
	}

	private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
			AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
		List<Message> messages = new ArrayList<>(previousMessages);
		messages.add(assistantMessage);
		messages.add(toolResponseMessage);
		return messages;
	}	
}	

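DefaultToolCallingManager's resolveToolDefinitions method starts from chatOptions.getToolCallbacks(). For every name in chatOptions.getToolNames() that is not already among those callbacks, it resolves a callback through toolCallbackResolver and throws an IllegalStateException if none is found. It then maps each callback to a ToolDefinition. A ToolCallback supplies its own definition, while for a plain FunctionCallback one is built from getName, getDescription and getInputTypeSchema.

The executeToolCalls method works as follows:

1. It picks the first Generation whose AssistantMessage contains tool calls, throwing an IllegalStateException if there is none.
2. It builds the ToolContext.
3. It runs executeToolCall to obtain the execution result.
4. It builds the conversation history: the prompt's instructions, followed by the assistant message and the tool response message.
5. It returns that history together with the returnDirect flag in a ToolExecutionResult.

The executeToolCall method iterates over assistantMessage.getToolCalls(). For each call it looks the tool up by name among the callbacks in the prompt options, falling back to toolCallbackResolver.resolve(toolName), and then obtains the result with toolCallback.call(toolInputArguments, toolContext). If a ToolExecutionException is thrown, toolExecutionExceptionProcessor.process(ex) handles it as a fallback. The returnDirect flags of the ToolCallback instances are combined with a logical AND. A plain FunctionCallback sets the flag to false only when no earlier tool has set it.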
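The lookup order and error handling in executeToolCall can be modeled with plain functions. In this sketch all types are hypothetical stand-ins: SketchToolCall stands in for AssistantMessage.ToolCall, SketchToolFailure for ToolExecutionException, and so on. Tools are String-to-String functions, and returnDirect handling and ToolContext are left out.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical simplified message types; none of these are Spring AI classes.
record SketchToolCall(String id, String name, String arguments) {}
record SketchToolResponse(String id, String name, String result) {}

// Stand-in for ToolExecutionException.
class SketchToolFailure extends RuntimeException {
    SketchToolFailure(String message) { super(message); }
}

// Sketch of the executeToolCall loop: tools passed with the request win, the
// resolver is the fallback, and tool failures go through an exception processor.
class SketchToolExecutor {

    private final Map<String, Function<String, String>> requestTools;
    private final Function<String, Function<String, String>> fallbackResolver;
    private final Function<SketchToolFailure, String> exceptionProcessor;

    SketchToolExecutor(Map<String, Function<String, String>> requestTools,
            Function<String, Function<String, String>> fallbackResolver,
            Function<SketchToolFailure, String> exceptionProcessor) {
        this.requestTools = requestTools;
        this.fallbackResolver = fallbackResolver;
        this.exceptionProcessor = exceptionProcessor;
    }

    List<SketchToolResponse> execute(List<SketchToolCall> toolCalls) {
        List<SketchToolResponse> responses = new ArrayList<>();
        for (SketchToolCall call : toolCalls) {
            Function<String, String> tool = requestTools.get(call.name());
            if (tool == null) {
                tool = fallbackResolver.apply(call.name());
            }
            if (tool == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + call.name());
            }
            String result;
            try {
                result = tool.apply(call.arguments());
            }
            catch (SketchToolFailure ex) {
                // The processor may turn the failure into a message for the model.
                result = exceptionProcessor.apply(ex);
            }
            // Keep the call id so the model can match each response to its request.
            responses.add(new SketchToolResponse(call.id(), call.name(), result));
        }
        return responses;
    }
}
```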

ToolExecutionExceptionProcessor

org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java

@FunctionalInterface
public interface ToolExecutionExceptionProcessor {

	/**
	 * Convert an exception thrown by a tool to a String that can be sent back to the AI
	 * model or throw an exception to be handled by the caller.
	 */
	String process(ToolExecutionException exception);

}

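ToolExecutionExceptionProcessor defines a single method, process. It either converts an exception thrown by a tool into a String to send back to the AI model, or throws an exception for the caller to handle.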

DefaultToolExecutionExceptionProcessor

public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {

	private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);

	private static final boolean DEFAULT_ALWAYS_THROW = false;

	private final boolean alwaysThrow;

	public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
		this.alwaysThrow = alwaysThrow;
	}

	@Override
	public String process(ToolExecutionException exception) {
		Assert.notNull(exception, "exception cannot be null");
		if (alwaysThrow) {
			throw exception;
		}
		logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(),
				exception.getMessage());
		return exception.getMessage();
	}

	//......
}	

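If alwaysThrow is true (it defaults to false), DefaultToolExecutionExceptionProcessor rethrows the exception. Otherwise it logs the exception at debug level and returns its message, which DefaultToolCallingManager then uses as the tool result.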
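Here is a minimal model of the alwaysThrow switch. It uses the hypothetical SketchToolExecutionException and SketchExceptionProcessor types in place of the Spring AI classes, and writes to stderr instead of using a logger.

```java
// Hypothetical stand-in for ToolExecutionException, carrying the tool name.
class SketchToolExecutionException extends RuntimeException {
    private final String toolName;

    SketchToolExecutionException(String toolName, String message) {
        super(message);
        this.toolName = toolName;
    }

    String toolName() { return toolName; }
}

// Mirrors DefaultToolExecutionExceptionProcessor: rethrow when alwaysThrow is set,
// otherwise return the message so it can be sent back to the model as the tool result.
class SketchExceptionProcessor {
    private final boolean alwaysThrow;

    SketchExceptionProcessor(boolean alwaysThrow) { this.alwaysThrow = alwaysThrow; }

    String process(SketchToolExecutionException exception) {
        if (exception == null) {
            throw new IllegalArgumentException("exception cannot be null");
        }
        if (alwaysThrow) {
            throw exception;
        }
        // Stands in for the debug log of tool name and message in the original.
        System.err.println("Tool " + exception.toolName() + " failed: " + exception.getMessage());
        return exception.getMessage();
    }
}
```

Returning the message lets the model see the error and possibly explain it to the user. Setting alwaysThrow to true surfaces the failure to the calling code instead.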

User-Controlled Tool Execution

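ToolCallingChatOptions provides an internalToolExecutionEnabled property. Setting it to false lets you control tool execution yourself. Alternatively, you can implement ToolExecutionEligibilityPredicate to decide when tools are executed. For example: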

ChatModel chatModel = ...
ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();

ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(ToolCallbacks.from(new CustomerTools()))
    .internalToolExecutionEnabled(false)
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);

ChatResponse chatResponse = chatModel.call(prompt);

while (chatResponse.hasToolCalls()) {
    ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

    prompt = new Prompt(toolExecutionResult.conversationHistory(), chatOptions);

    chatResponse = chatModel.call(prompt);
}

System.out.println(chatResponse.getResult().getOutput().getText());

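Here the application calls toolCallingManager.executeToolCalls itself and passes the resulting conversation history back to chatModel. It repeats this until the model stops requesting tool calls.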
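The control flow of that loop can be exercised without a real model. In the sketch below, FakeModel is a hypothetical stand-in that requests one tool call and then answers. The history is a list of plain strings rather than Message objects, so only the loop structure mirrors the example above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical fake model: it requests a tool until a tool result is in the
// history, then answers using it. Stands in for chatModel.call(prompt).
class FakeModel {
    // Returns either "TOOL:<arguments>" (a tool call request) or a final answer.
    String call(List<String> history) {
        String last = history.get(history.size() - 1);
        if (last.startsWith("TOOL_RESULT:")) {
            return "The customer is " + last.substring("TOOL_RESULT:".length());
        }
        return "TOOL:42";
    }
}

class UserControlledLoop {
    static String run(FakeModel model, Function<String, String> tool, String userMessage) {
        List<String> history = new ArrayList<>();
        history.add(userMessage);
        String response = model.call(history);
        // Same shape as the while (chatResponse.hasToolCalls()) loop above.
        while (response.startsWith("TOOL:")) {
            String toolResult = tool.apply(response.substring("TOOL:".length()));
            history.add(response);                      // assistant message with the tool call
            history.add("TOOL_RESULT:" + toolResult);   // tool response message
            response = model.call(history);
        }
        return response;
    }
}
```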

ToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java

public interface ToolCallbackResolver {

	/**
	 * Resolve the {@link FunctionCallback} for the given tool name.
	 */
	@Nullable
	FunctionCallback resolve(String toolName);

}

ToolCallbackResolver defines the resolve method, which returns the FunctionCallback for a given tool name, or null if there is none. It has three implementations: StaticToolCallbackResolver, SpringBeanToolCallbackResolver and DelegatingToolCallbackResolver.

StaticToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java

public class StaticToolCallbackResolver implements ToolCallbackResolver {

	private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class);

	private final Map<String, FunctionCallback> toolCallbacks = new HashMap<>();

	public StaticToolCallbackResolver(List<FunctionCallback> toolCallbacks) {
		Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
		Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");

		toolCallbacks.forEach(callback -> {
			if (callback instanceof ToolCallback toolCallback) {
				this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback);
			}
			this.toolCallbacks.put(callback.getName(), callback);
		});
	}

	@Override
	public FunctionCallback resolve(String toolName) {
		Assert.hasText(toolName, "toolName cannot be null or empty");
		logger.debug("ToolCallback resolution attempt from static registry");
		return toolCallbacks.get(toolName);
	}

}

StaticToolCallbackResolver looks tools up in the List<FunctionCallback> passed to its constructor, which it indexes by tool name in a HashMap.
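Here is a stripped-down version of the same idea. It assumes a hypothetical NamedTool interface in place of FunctionCallback and a hypothetical SimpleNamedTool record as the example tool.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.UnaryOperator;

// Hypothetical minimal tool type: a name plus a behavior.
interface NamedTool {
    String name();
    String call(String input);
}

// Mirrors StaticToolCallbackResolver: index a fixed list by name once,
// then resolve by map lookup, returning null for unknown names.
class SketchStaticResolver {
    private final Map<String, NamedTool> tools = new HashMap<>();

    SketchStaticResolver(List<NamedTool> toolList) {
        for (NamedTool tool : toolList) {
            if (tool == null) {
                throw new IllegalArgumentException("tools cannot contain null elements");
            }
            tools.put(tool.name(), tool);   // a later tool with the same name replaces an earlier one
        }
    }

    NamedTool resolve(String toolName) {
        if (toolName == null || toolName.isBlank()) {
            throw new IllegalArgumentException("toolName cannot be null or empty");
        }
        return tools.get(toolName);
    }
}

record SimpleNamedTool(String name, UnaryOperator<String> fn) implements NamedTool {
    public String call(String input) { return fn.apply(input); }
}
```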

SpringBeanToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java

public class SpringBeanToolCallbackResolver implements ToolCallbackResolver {

	private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class);

	private static final Map<String, ToolCallback> toolCallbacksCache = new HashMap<>();

	private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA;

	private final GenericApplicationContext applicationContext;

	private final SchemaType schemaType;

	public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext,
			@Nullable SchemaType schemaType) {
		Assert.notNull(applicationContext, "applicationContext cannot be null");

		this.applicationContext = applicationContext;
		this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE;
	}

	@Override
	public ToolCallback resolve(String toolName) {
		Assert.hasText(toolName, "toolName cannot be null or empty");

		logger.debug("ToolCallback resolution attempt from Spring application context");

		ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName);

		if (resolvedToolCallback != null) {
			return resolvedToolCallback;
		}

		ResolvableType toolType = TypeResolverHelper.resolveBeanType(applicationContext, toolName);
		ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType))
				? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0);

		String toolDescription = resolveToolDescription(toolName, toolInputType.toClass());
		Object bean = applicationContext.getBean(toolName);

		resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean);

		toolCallbacksCache.put(toolName, resolvedToolCallback);

		return resolvedToolCallback;
	}

	//......
}	

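SpringBeanToolCallbackResolver uses the GenericApplicationContext to look up the bean named toolName in the Spring container. It determines the bean's input type, treating Supplier beans as taking no input, then builds a ToolCallback from the bean and stores it in toolCallbacksCache. Note that toolCallbacksCache is a static HashMap, so all instances of the resolver share it.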
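The resolve-then-cache pattern can be sketched with a plain map playing the role of the application context. CachingBeanResolver is hypothetical. Unlike the original, it keeps its cache per instance in a ConcurrentHashMap. This is a deliberate choice so that separate resolvers do not share state and concurrent lookups are safe.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Sketch of the resolve-then-cache pattern in SpringBeanToolCallbackResolver.
// A plain map of named functions stands in for the Spring ApplicationContext.
class CachingBeanResolver {
    private final Map<String, Function<String, String>> beans;   // hypothetical "container"
    private final Map<String, Function<String, String>> cache = new ConcurrentHashMap<>();
    final AtomicInteger lookups = new AtomicInteger();           // counts container lookups

    CachingBeanResolver(Map<String, Function<String, String>> beans) {
        this.beans = beans;
    }

    Function<String, String> resolve(String toolName) {
        // computeIfAbsent stores only non-null results, so unknown names are looked up again next time.
        return cache.computeIfAbsent(toolName, name -> {
            lookups.incrementAndGet();
            return beans.get(name);   // the real resolver also wraps the bean into a ToolCallback
        });
    }
}
```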

DelegatingToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java

public class DelegatingToolCallbackResolver implements ToolCallbackResolver {

	private final List<ToolCallbackResolver> toolCallbackResolvers;

	public DelegatingToolCallbackResolver(List<ToolCallbackResolver> toolCallbackResolvers) {
		Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null");
		Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements");
		this.toolCallbackResolvers = toolCallbackResolvers;
	}

	@Override
	@Nullable
	public FunctionCallback resolve(String toolName) {
		Assert.hasText(toolName, "toolName cannot be null or empty");

		for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
			FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
			if (toolCallback != null) {
				return toolCallback;
			}
		}
		return null;
	}

}

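DelegatingToolCallbackResolver delegates resolve to the toolCallbackResolvers passed to its constructor. It tries them in order and returns the first non-null result, or null if none of them can resolve the name.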
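Here is a minimal sketch of the chain. It uses hypothetical SketchResolver and SketchDelegatingResolver types, with String-to-String functions as tools; fromMap is a helper added only for the example.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical resolver contract: return a tool for the name, or null.
@FunctionalInterface
interface SketchResolver {
    Function<String, String> resolve(String toolName);
}

// Mirrors DelegatingToolCallbackResolver: ask each delegate in order and
// return the first non-null answer, or null when no delegate knows the name.
class SketchDelegatingResolver implements SketchResolver {
    private final List<SketchResolver> delegates;

    SketchDelegatingResolver(List<SketchResolver> delegates) {
        this.delegates = List.copyOf(delegates);   // also rejects null elements
    }

    @Override
    public Function<String, String> resolve(String toolName) {
        for (SketchResolver delegate : delegates) {
            Function<String, String> tool = delegate.resolve(toolName);
            if (tool != null) {
                return tool;
            }
        }
        return null;
    }

    // Adapts a fixed map into a resolver (a helper for this example only).
    static SketchResolver fromMap(Map<String, Function<String, String>> tools) {
        return tools::get;
    }
}
```

The first non-null answer wins. Therefore the order of the delegates decides which tool is used when two resolvers know the same name.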

Summary

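Spring AI implements Tool Calling through ToolCallback. ToolCallback extends the FunctionCallback interface, which is about to be deprecated, and mainly defines the getToolDefinition, getToolMetadata and call methods. Its two basic implementations are MethodToolCallback and FunctionToolCallback.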

The Tool Specification covers Tool Callback, Tool Definition, JSON Schema, Result Conversion, Tool Context and Return Direct. Tool Execution covers Framework-Controlled Tool Execution, User-Controlled Tool Execution and Exception Handling.

This article is part of the Tencent Cloud self-media sync program and was shared from a WeChat official account. Originally published: 2025-04-11.
