Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import com.microsoft.semantickernel.aiservices.openai.OpenAiService;
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.responseformat.ChatCompletionsJsonSchemaResponseFormat;
import com.microsoft.semantickernel.aiservices.openai.implementation.OpenAIRequestSettings;
import com.microsoft.semantickernel.contents.FunctionCallContent;
import com.microsoft.semantickernel.contextvariables.ContextVariable;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypes;
import com.microsoft.semantickernel.exceptions.AIException;
Expand Down Expand Up @@ -468,7 +469,6 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(
.doOnTerminate(span::close);
})
.flatMap(completions -> {

List<ChatResponseMessage> responseMessages = completions
.getChoices()
.stream()
Expand All @@ -488,6 +488,7 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(
completions);
return Mono.just(messages.addChatMessage(chatMessageContents));
}

// Or if there are no tool calls to be done
ChatResponseMessage response = responseMessages.get(0);
List<ChatCompletionsToolCall> toolCalls = response.getToolCalls();
Expand Down Expand Up @@ -633,21 +634,21 @@ private Mono<FunctionResult<String>> invokeFunctionTool(
ContextVariableTypes contextVariableTypes) {

try {
OpenAIFunctionToolCall openAIFunctionToolCall = extractOpenAIFunctionToolCall(toolCall);
String pluginName = openAIFunctionToolCall.getPluginName();
FunctionCallContent functionCallContent = extractFunctionCallContent(toolCall);
String pluginName = functionCallContent.getPluginName();
if (pluginName == null || pluginName.isEmpty()) {
return Mono.error(
new SKException("Plugin name is required for function tool call"));
}

KernelFunction<?> function = kernel.getFunction(
pluginName,
openAIFunctionToolCall.getFunctionName());
functionCallContent.getFunctionName());

PreToolCallEvent hookResult = executeHook(invocationContext, kernel,
new PreToolCallEvent(
openAIFunctionToolCall.getFunctionName(),
openAIFunctionToolCall.getArguments(),
functionCallContent.getFunctionName(),
functionCallContent.getArguments(),
function,
contextVariableTypes));

Expand Down Expand Up @@ -686,7 +687,7 @@ private static <T extends KernelHookEvent> T executeHook(
}

@SuppressWarnings("StringSplitter")
private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(
private FunctionCallContent extractFunctionCallContent(
ChatCompletionsFunctionToolCall toolCall)
throws JsonProcessingException {

Expand All @@ -712,10 +713,10 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(
}
});

return new OpenAIFunctionToolCall(
toolCall.getId(),
pluginName,
return new FunctionCallContent(
fnName,
pluginName,
toolCall.getId(),
arguments);
}

Expand Down Expand Up @@ -744,7 +745,7 @@ private List<OpenAIChatMessageContent<?>> getChatMessageContentsAsync(
null,
null,
completionMetadata,
formOpenAiToolCalls(response));
formFunctionCallContents(response));
} catch (SKCheckedException e) {
LOGGER.warn("Failed to form chat message content", e);
return null;
Expand Down Expand Up @@ -784,7 +785,7 @@ private List<ChatMessageContent<?>> toOpenAIChatMessageContent(
null);
} else if (message instanceof ChatRequestAssistantMessage) {
try {
List<OpenAIFunctionToolCall> calls = getToolCalls(
List<FunctionCallContent> calls = getFunctionCallContents(
((ChatRequestAssistantMessage) message).getToolCalls());
return new OpenAIChatMessageContent<>(
AuthorRole.ASSISTANT,
Expand Down Expand Up @@ -823,7 +824,7 @@ private List<ChatMessageContent<?>> toOpenAIChatMessageContent(
}

@Nullable
private List<OpenAIFunctionToolCall> getToolCalls(
private List<FunctionCallContent> getFunctionCallContents(
@Nullable List<ChatCompletionsToolCall> toolCalls) throws SKCheckedException {
if (toolCalls == null || toolCalls.isEmpty()) {
return null;
Expand All @@ -835,7 +836,7 @@ private List<OpenAIFunctionToolCall> getToolCalls(
.map(call -> {
if (call instanceof ChatCompletionsFunctionToolCall) {
try {
return extractOpenAIFunctionToolCall(
return extractFunctionCallContent(
(ChatCompletionsFunctionToolCall) call);
} catch (JsonProcessingException e) {
throw SKException.build("Failed to parse tool arguments", e);
Expand All @@ -852,7 +853,7 @@ private List<OpenAIFunctionToolCall> getToolCalls(
}

@Nullable
private List<OpenAIFunctionToolCall> formOpenAiToolCalls(
private List<FunctionCallContent> formFunctionCallContents(
ChatResponseMessage response) throws SKCheckedException {
if (response.getToolCalls() == null || response.getToolCalls().isEmpty()) {
return null;
Expand All @@ -864,7 +865,7 @@ private List<OpenAIFunctionToolCall> formOpenAiToolCalls(
.map(call -> {
if (call instanceof ChatCompletionsFunctionToolCall) {
try {
return extractOpenAIFunctionToolCall(
return extractFunctionCallContent(
(ChatCompletionsFunctionToolCall) call);
} catch (JsonProcessingException e) {
throw SKException.build("Failed to parse tool arguments", e);
Expand Down Expand Up @@ -1251,10 +1252,7 @@ private static ChatRequestAssistantMessage formAssistantMessage(
// TODO: handle tools other than function calls
ChatRequestAssistantMessage asstMessage = new ChatRequestAssistantMessage(content);

List<OpenAIFunctionToolCall> toolCalls = null;
if (message instanceof OpenAIChatMessageContent) {
toolCalls = ((OpenAIChatMessageContent<?>) message).getToolCall();
}
List<FunctionCallContent> toolCalls = FunctionCallContent.getFunctionCalls(message);

if (toolCalls != null) {
asstMessage.setToolCalls(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;

import com.microsoft.semantickernel.contents.FunctionCallContent;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.services.KernelContent;
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
Expand All @@ -16,6 +19,7 @@
*/
public class OpenAIChatMessageContent<T> extends ChatMessageContent<T> {

@Deprecated
@Nullable
private final List<OpenAIFunctionToolCall> toolCall;

Expand All @@ -28,7 +32,7 @@ public class OpenAIChatMessageContent<T> extends ChatMessageContent<T> {
* @param innerContent The inner content.
* @param encoding The encoding.
* @param metadata The metadata.
* @param toolCall The tool call.
* @param functionCalls The tool call.
*/
public OpenAIChatMessageContent(
AuthorRole authorRole,
Expand All @@ -37,21 +41,36 @@ public OpenAIChatMessageContent(
@Nullable T innerContent,
@Nullable Charset encoding,
@Nullable FunctionResultMetadata<?> metadata,
@Nullable List<OpenAIFunctionToolCall> toolCall) {
super(authorRole, content, modelId, innerContent, encoding, metadata);
@Nullable List<? extends FunctionCallContent> functionCalls) {
super(authorRole, content, (List<? extends KernelContent<T>>) functionCalls, modelId,
innerContent, encoding, metadata);

if (toolCall == null) {
if (functionCalls == null) {
this.toolCall = null;
} else {
this.toolCall = Collections.unmodifiableList(toolCall);
// Keep OpenAIFunctionToolCall list for legacy
this.toolCall = Collections.unmodifiableList(functionCalls.stream().map(t -> {
if (t instanceof OpenAIFunctionToolCall) {
return (OpenAIFunctionToolCall) t;
} else {
return new OpenAIFunctionToolCall(
t.getId(),
t.getPluginName(),
t.getFunctionName(),
t.getArguments());
}
}).collect(Collectors.toList()));
}
}

/**
* Gets any tool calls requested.
*
* @return The tool call.
*
* @deprecated Use {@link FunctionCallContent#getFunctionCalls(ChatMessageContent)} instead.
*/
@Deprecated
@Nullable
public List<OpenAIFunctionToolCall> getToolCall() {
return toolCall;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;

import com.microsoft.semantickernel.contents.FunctionCallContent;
import com.microsoft.semantickernel.semanticfunctions.KernelArguments;
import javax.annotation.Nullable;

/**
* Represents a call to a function in the OpenAI tool.
*
* @deprecated Use {@link FunctionCallContent} instead.
*/
public class OpenAIFunctionToolCall {

/// <summary>Gets the ID of the tool call.</summary>
@Nullable
private final String id;

/// <summary>Gets the name of the plugin with which this function is associated, if any.</summary>

@Nullable
private final String pluginName;

/// <summary>Gets the name of the function.</summary>
private final String functionName;

/// <summary>Gets a name/value collection of the arguments to the function, if any.</summary>
@Nullable
private final KernelArguments arguments;
@Deprecated
public class OpenAIFunctionToolCall extends FunctionCallContent {

/**
* Creates a new instance of the {@link OpenAIFunctionToolCall} class.
Expand All @@ -38,55 +26,6 @@ public OpenAIFunctionToolCall(
@Nullable String pluginName,
String functionName,
@Nullable KernelArguments arguments) {
this.id = id;
this.pluginName = pluginName;
this.functionName = functionName;
if (arguments == null) {
this.arguments = null;
} else {
this.arguments = arguments.copy();
}
}

/**
* Gets the ID of the tool call.
*
* @return The ID of the tool call.
*/
@Nullable
public String getId() {
return id;
}

/**
* Gets the name of the plugin with which this function is associated, if any.
*
* @return The name of the plugin with which this function is associated, if any.
*/
@Nullable
public String getPluginName() {
return pluginName;
}

/**
* Gets the name of the function.
*
* @return The name of the function.
*/
public String getFunctionName() {
return functionName;
}

/**
* Gets a name/value collection of the arguments to the function, if any.
*
* @return A name/value collection of the arguments to the function, if any.
*/
@Nullable
public KernelArguments getArguments() {
if (arguments == null) {
return null;
}
return arguments.copy();
super(functionName, pluginName, id, arguments);
}
}
Loading