Source code for lwe.backends.api.request
import copy
from langchain_community.adapters.openai import convert_message_to_dict
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from lwe.core.logger import Logger
from lwe.core import constants
import lwe.core.util as util
from lwe.core.tool_cache import ToolCache
from lwe.core.token_manager import TokenManager
from lwe.backends.api.orm import Orm
from lwe.backends.api.message import MessageManager
[docs]
class ApiRequest:
"""Individual LLM requests manager"""
def __init__(
self,
config=None,
provider=None,
provider_manager=None,
tool_manager=None,
input=None,
preset=None,
preset_manager=None,
system_message=None,
old_messages=None,
max_submission_tokens=None,
request_overrides=None,
return_only=False,
orm=None,
):
self.config = config
self.log = Logger(self.__class__.__name__, self.config)
self.default_provider = provider
self.provider = self.default_provider
self.provider_manager = provider_manager
self.tool_manager = tool_manager
self.input = input
self.default_preset = preset
self.default_preset_name = util.get_preset_name(self.default_preset)
self.preset_manager = preset_manager
self.system_message = system_message or constants.SYSTEM_MESSAGE_DEFAULT
self.old_messages = old_messages or []
self.max_submission_tokens = (
max_submission_tokens or constants.OPEN_AI_DEFAULT_MAX_SUBMISSION_TOKENS
)
self.request_overrides = request_overrides or {}
self.return_only = return_only
self.orm = orm or Orm(self.config)
self.message = MessageManager(config, self.orm)
self.streaming = False
self.log.debug(
f"Inintialized ApiRequest with input: {self.input}, default preset name: {self.default_preset_name}, system_message: {self.system_message}, max_submission_tokens: {self.max_submission_tokens}, request_overrides: {self.request_overrides}, return only: {self.return_only}"
)
[docs]
def set_request_llm(self):
success, response, user_message = self.extract_metadata_customizations()
if not success:
return success, response, user_message
preset_name, preset_overrides, metadata, customizations = response
success, response, user_message = self.setup_request_config(
preset_name, preset_overrides, metadata, customizations
)
return success, response, user_message
[docs]
def setup_request_config(
self, preset_name=None, preset_overrides=None, metadata=None, customizations=None
):
"""
Set up the configuration for the request.
:param preset_name: Override preset name
:type preset_name: str, optional
:param preset_overrides: Overrides for preset, defaults to None
:type preset_overrides: dict, optional
:param metadata: Preset metadata
:type metadata: dict, optional
:param customizations: Provider/model customizations
:type customizations: dict, optional
:returns: success, llm, message
:rtype: tuple
"""
config = {
"preset_name": preset_name,
"preset_overrides": preset_overrides,
"metadata": metadata,
"customizations": customizations,
}
success, response, user_message = self.build_request_config(config)
if success:
provider, preset, llm, preset_name, model_name, token_manager = response
self.llm = llm
self.provider = provider
self.token_manager = token_manager
self.preset_name = preset_name
self.model_name = model_name
self.preset = preset
return success, response, user_message
[docs]
def build_request_config(self, config):
config = self.prepare_config(config)
success, provider, user_message = self.load_provider(config)
if not success:
return success, provider, user_message
config = self.merge_preset_overrides(config)
preset = (config["metadata"], config["customizations"])
customizations, tools, tool_choice = self.expand_tools(config["customizations"])
config["customizations"] = customizations
llm = provider.make_llm(
config["customizations"], tools=tools, tool_choice=tool_choice, use_defaults=True
)
preset_name = config["metadata"].get("name", "")
model_name = getattr(llm, provider.model_property_name)
success, response, user_message = provider.validate_model(model_name)
if not success:
return success, response, user_message
token_manager = TokenManager(self.config, provider, model_name, self.tool_cache)
message = f"Built LLM based on preset_name: {preset_name or 'None'}, metadata: {config['metadata']}, customizations: {config['customizations']}, preset_overrides: {config['preset_overrides']}"
self.log.debug(message)
return True, (provider, preset, llm, preset_name, model_name, token_manager), message
[docs]
def prepare_config(self, config):
config["metadata"] = copy.deepcopy(config["metadata"] or {})
config["customizations"] = copy.deepcopy(config["customizations"] or {})
config["preset_overrides"] = config["preset_overrides"] or {}
return config
[docs]
def load_provider(self, config):
if "provider" in config["metadata"]:
return self.provider_manager.load_provider(config["metadata"]["provider"])
return True, self.provider, "Default provider loaded"
[docs]
def merge_preset_overrides(self, config):
if config["preset_overrides"]:
if "metadata" in config["preset_overrides"]:
self.log.info(
f"Merging preset overrides for metadata: {config['preset_overrides']['metadata']}"
)
config["metadata"] = util.merge_dicts(
config["metadata"], config["preset_overrides"]["metadata"]
)
if "model_customizations" in config["preset_overrides"]:
self.log.info(
f"Merging preset overrides for model customizations: {config['preset_overrides']['model_customizations']}"
)
config["customizations"] = util.merge_dicts(
config["customizations"], config["preset_overrides"]["model_customizations"]
)
return config
[docs]
def extract_metadata_customizations(self):
self.log.debug(
f"Extracting preset configuration from request_overrides: {self.request_overrides}"
)
success, response, user_message = util.extract_preset_configuration_from_request_overrides(
self.request_overrides, self.default_preset_name
)
if not success:
return success, response, user_message
preset_name, preset_overrides, _activate_preset = response
metadata = {}
customizations = {}
if preset_name:
self.log.debug(f"Preset {preset_name!r} extracted from request overrides")
success, response, user_message = self.get_preset_metadata_customizations(preset_name)
if success:
metadata, customizations = response
else:
return success, response, user_message
else:
if self.default_preset:
self.log.debug("Using default preset")
metadata, _ = self.default_preset
else:
self.log.debug("Using current provider")
metadata["provider"] = self.provider.name
customizations = self.provider.get_customizations()
return (
success,
(preset_name, preset_overrides, metadata, customizations),
"Extracted metadata and customizations",
)
[docs]
def get_preset_metadata_customizations(self, preset_name):
success, preset, user_message = self.preset_manager.ensure_preset(preset_name)
if not success:
return success, preset, user_message
metadata, customizations = preset
self.log.debug(
f"Retrieved metadata and customizations for preset: {preset_name}, metadata: {metadata}, customizations: {customizations}"
)
return (
success,
(metadata, customizations),
f"Retrieved metadata and customizations for preset: {preset_name}",
)
[docs]
def expand_tools(self, customizations):
"""Expand any configured tools to their full definition.
:param customizations: Model customizations
:type customizations: dict
:returns: customizations, tools, tool_choice
:rtype: tuple
"""
customizations = copy.deepcopy(customizations)
self.tool_cache = ToolCache(self.config, self.tool_manager, customizations)
self.tool_cache.add_message_tools(self.old_messages)
tools = [
self.tool_manager.get_tool_config(tool_name) for tool_name in self.tool_cache.tools
]
if "tools" in customizations:
del customizations["tools"]
tool_choice = customizations.pop("tool_choice", None)
return customizations, tools, tool_choice
[docs]
def prepare_default_new_conversation_messages(self):
"""
Prepare default new conversation messages.
:returns: List of new messages
:rtype: list
"""
new_messages = []
if len(self.old_messages) == 0:
new_messages.append(
self.message.build_message(
"system", self.request_overrides.get("system_message") or self.system_message
)
)
new_messages.append(self.message.build_message("user", self.input))
return new_messages
[docs]
def prepare_custom_new_conversation_messages(self):
"""
Prepare custom new conversation messages.
:returns: List of new messages
:rtype: list
"""
return [
self.message.build_message(message["role"], message["content"])
for message in self.input
]
[docs]
def prepare_new_conversation_messages(self):
"""
Prepare new conversation messages.
:returns: List of new messages
:rtype: list
"""
return (
self.prepare_custom_new_conversation_messages()
if type(self.input) is list
else self.prepare_default_new_conversation_messages()
)
[docs]
def prepare_ask_request(self):
"""
Prepare the request for the LLM.
:returns: New messages, messages
:rtype: tuple
"""
new_messages = self.prepare_new_conversation_messages()
old_messages = self.tool_cache.add_message_tools(self.old_messages)
messages = old_messages + new_messages
messages = self.strip_out_messages_over_max_tokens(messages, self.max_submission_tokens)
return new_messages, messages
[docs]
def strip_out_messages_over_max_tokens(self, messages, max_tokens):
"""
Recursively strip out messages over max tokens.
:param messages: Messages
:type messages: list
:param max_tokens: Max tokens
:type max_tokens: int
:returns: Messages
:rtype: list
"""
messages = copy.deepcopy(messages)
token_count = self.token_manager.get_num_tokens_from_messages(messages)
self.log.debug(
f"Stripping messages over max tokens: {max_tokens}, initial token count: {token_count}"
)
stripped_messages_count = 0
while token_count > max_tokens and len(messages) > 1:
message = messages.pop(0)
token_count = self.token_manager.get_num_tokens_from_messages(messages)
self.log.debug(
f"Stripping message: {message['role']}, {message['message']} -- new token count: {token_count}"
)
stripped_messages_count += 1
token_count = self.token_manager.get_num_tokens_from_messages(messages)
if token_count > max_tokens:
raise Exception(
f"No messages to send, all messages have been stripped, still over max submission tokens: {max_tokens}"
)
if stripped_messages_count > 0:
max_tokens_exceeded_warning = f"Conversation exceeded max submission tokens ({max_tokens}), stripped out {stripped_messages_count} oldest messages before sending, sent {token_count} tokens instead"
self.log.warning(max_tokens_exceeded_warning)
util.print_status_message(False, max_tokens_exceeded_warning)
return messages
[docs]
def call_llm(self, messages):
"""
Call the LLM.
:param messages: Messages
:type messages: list
:returns: success, response, message
:rtype: tuple
"""
stream = self.request_overrides.get("stream", False)
self.log.debug(f"Calling LLM with message count: {len(messages)}")
llm_pre_call_method = getattr(self.provider, "llm_pre_call", None)
if llm_pre_call_method:
messages = llm_pre_call_method(self.llm, messages)
messages = self.build_chat_request(messages)
if stream:
return self.execute_llm_streaming(messages)
else:
return self.execute_llm_non_streaming(messages)
[docs]
def build_chat_request(self, messages):
"""
Build chat request for LLM.
:param messages: Messages
:type messages: list
:returns: Prepared messages
:rtype: list
"""
self.log.debug(f"Building messages for LLM, message count: {len(messages)}")
messages = util.transform_messages_to_chat_messages(messages)
messages = self.provider.prepare_messages_for_llm(messages)
messages = self.attach_files(messages)
return messages
[docs]
def attach_files(self, messages):
files = self.request_overrides.get("files", [])
if files:
for file in files:
file_data = self.provider.prepare_file_for_llm(file)
messages.append(file_data)
return messages
[docs]
def output_chunk_content(self, content, print_stream, stream_callback):
if content:
if print_stream:
print(content, end="", flush=True)
if stream_callback:
stream_callback(content)
[docs]
def iterate_streaming_response(self, messages, print_stream, stream_callback):
response = None
previous_chunks = []
self.log.debug(f"Streaming with LLM attributes: {self.llm.dict()}")
provider_streaming_method = getattr(self.provider, "handle_streaming_chunk", None)
for chunk in self.llm.stream(messages):
if provider_streaming_method:
content = provider_streaming_method(chunk, previous_chunks)
response = content if not response else response + content
elif isinstance(chunk, AIMessageChunk) or isinstance(chunk, AIMessage):
content = chunk.content
response = chunk if not response else response + chunk
elif isinstance(chunk, str):
content = chunk
response = content if not response else response + content
else:
raise ValueError(f"Unexpected chunk type: {type(chunk)}")
self.output_chunk_content(content, print_stream, stream_callback)
if not self.streaming:
if getattr(response, "tool_call_chunks", None):
response = None
util.print_status_message(False, "Generation stopped")
break
previous_chunks.append(chunk)
return response
[docs]
def execute_llm_streaming(self, messages):
self.log.debug(f"Started streaming request at {util.current_datetime().isoformat()}")
response = ""
print_stream = self.request_overrides.get("print_stream", False)
stream_callback = self.request_overrides.get("stream_callback", None)
# Start streaming loop.
self.streaming = True
try:
response = self.iterate_streaming_response(messages, print_stream, stream_callback)
except ValueError as e:
return False, messages, e
finally:
# End streaming loop.
self.streaming = False
self.log.debug(f"Stopped streaming response at {util.current_datetime().isoformat()}")
return True, response, "Response received"
[docs]
def execute_llm_non_streaming(self, messages):
self.log.info("Starting non-streaming request")
self.log.debug(f"Non-streaming with LLM attributes: {self.llm.dict()}")
provider_non_streaming_method = getattr(self.provider, "handle_non_streaming_response", None)
try:
response = self.llm.invoke(messages)
if provider_non_streaming_method:
response = provider_non_streaming_method(response)
except ValueError as e:
return False, messages, e
return True, response, "Response received"
[docs]
def post_response(self, response_obj, new_messages):
response_message, tool_calls = self.extract_message_content(response_obj)
new_messages.append(response_message)
if tool_calls:
return self.handle_tool_calls(tool_calls, new_messages)
return self.handle_non_tool_response(response_message, new_messages)
[docs]
def handle_tool_calls(self, tool_calls, new_messages):
for tool_call in tool_calls:
self.log_tool_call(tool_call)
if self.should_return_on_tool_call():
names = [tool_call["name"] for tool_call in tool_calls]
self.log.info(f"Returning directly on tool call: {names}")
return tool_calls, new_messages
return self.execute_tool_calls(tool_calls, new_messages)
[docs]
def handle_non_tool_response(self, response_message, new_messages):
tool_response, new_messages = self.check_return_on_tool_response(new_messages)
if tool_response:
self.log.info("Returning directly on tool response")
return tool_response, new_messages
return response_message["message"], new_messages
[docs]
def log_tool_call(self, tool_call):
if not self.return_only:
util.print_markdown(
f"### AI requested tool call:\n* Name: {tool_call['name']}\n* Arguments: {tool_call['args']}"
)
[docs]
def execute_tool_call(self, tool_call):
success, tool_response, user_message = self.run_tool(tool_call["name"], tool_call["args"])
if not success:
raise ValueError(f"Tool call failed: {user_message}")
return tool_response
[docs]
def execute_tool_calls(self, tool_calls, new_messages):
for tool_call in tool_calls:
tool_response = self.execute_tool_call(tool_call)
new_messages.append(self.build_tool_response_message(tool_call, tool_response))
# If a tool call is forced, we cannot recurse, as there will
# never be a final non-tool response, and we'll recurse infinitely.
# TODO: Perhaps in the future we can handle this more elegantly by:
# 1. Tracking which tools with which arguments are called, and breaking
# on the first duplicate call.
# 2. Allowing a 'maximum_forced_tool_calls' metadata attribute.
# 3. Automatically switching the preset's 'tool_choice' to 'auto' after
# the first call.
if self.check_forced_tool():
self.log.debug("Returning directly on forced tool call")
return tool_response, new_messages
success, response_obj, user_message = self.call_llm(new_messages)
if not success:
raise ValueError(f"LLM call failed: {user_message}")
return self.post_response(response_obj, new_messages)
[docs]
def build_tool_response_message(self, tool_call, tool_response):
message_metadata = {
"name": tool_call["name"],
}
if "id" in tool_call:
message_metadata["id"] = tool_call["id"]
return self.message.build_message(
"tool",
tool_response,
message_type="tool_response",
message_metadata=message_metadata,
)
[docs]
def extract_message_content(self, message):
"""
Extract the content from an LLM message.
:param message: Message
:type message: dict | BaseMessage
:returns: Built message, tool calls
:rtype: tuple
"""
tool_calls = []
if isinstance(message, BaseMessage):
tool_calls = message.tool_calls
invalid_tool_calls = getattr(message, "invalid_tool_calls", [])
if invalid_tool_calls:
tool_call_errors = ", ".join(
[
f"{tool_call['name']}: {tool_call['error']}"
for tool_call in invalid_tool_calls
]
)
raise RuntimeError(f"LLM tool call failed: {tool_call_errors}")
message_dict = convert_message_to_dict(message)
content = message_dict["content"]
message_type = "content"
if tool_calls:
message_type = "tool_call"
content = tool_calls
return (
self.message.build_message(message_dict["role"], content, message_type),
tool_calls,
)
return self.message.build_message("assistant", message), tool_calls
[docs]
def should_return_on_tool_call(self):
"""
Check if should return on tool call.
:returns: Whether to return on tool call
:rtype: bool
"""
metadata, _customizations = self.preset
return "return_on_tool_call" in metadata and metadata["return_on_tool_call"]
[docs]
def check_forced_tool(self):
"""Check if a tool call is forced.
:returns: True if forced tool
:rtype: bool
"""
_metadata, customizations = self.preset
if "tool_choice" in customizations:
return not (
isinstance(customizations["tool_choice"], str)
and customizations["tool_choice"] in ["auto", "none"]
)
return False
[docs]
def check_return_on_tool_response(self, new_messages):
"""
Check for return on tool response.
Supports multiple tool calls.
:param new_messages: List of new messages
:type new_messages: list
:returns: Tool response or None, updated messages
:rtype: tuple
"""
metadata, _customizations = self.preset
if "return_on_tool_response" in metadata and metadata["return_on_tool_response"]:
# NOTE: In order to allow for multiple tool calling and
# returning on the LAST tool response, we need to allow
# the LLM to respond to all previous tool responses, as
# it may respond with another tool call.
#
# Thus, at the end of all responses from the LLM, the last
# message will be a natural language reponse, and the previous
# message will be the last tool response.
#
# To correctly return the tool response and message list
# we need to:
# 1. Remove the last message
# 2. Extract and return the tool response
if self.is_tool_response_message(new_messages[-2]):
new_messages.pop()
tool_response = new_messages[-1]["message"]
return tool_response, new_messages
return None, new_messages
[docs]
def run_tool(self, tool_name, data):
"""Run a tool.
:param tool_name: Tool name
:type tool_name: str
:param data: Tool arguments
:type data: dict
:returns: success, response, message
:rtype: tuple
"""
success, response, user_message = self.tool_manager.run_tool(tool_name, data)
json_obj = response if success else {"error": user_message}
if not self.return_only:
util.print_markdown(f"### Tool response:\n* Name: {tool_name}\n* Success: {success}")
util.print_markdown(json_obj)
return success, json_obj, user_message
[docs]
def is_tool_response_message(self, message):
"""Check if a message is a tool response.
:param message: The message
:type message: dict
:returns: True if tool response
:rtype: bool
"""
return message["message_type"] == "tool_response"
[docs]
def terminate_stream(self, _signal, _frame):
"""
Handles termination signal and stops the stream if it's running.
"""
self.log.info("Received signal to terminate stream")
if self.streaming:
self.streaming = False