from abc import abstractmethod
import copy
from typing import (
Any,
Mapping,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from lwe.core.plugin import Plugin
from lwe.core import constants
from lwe.core import util
[docs]
class PresetValue:
    """Describe one configurable value: its type, bounds, options and completions.

    Attributes set by the constructor:
        value_type: One of the types ``str``, ``int``, ``float``, ``bool`` or
            ``dict``; any other value makes ``build_completions`` raise.
        min_value / max_value: Optional bounds, used only to generate
            completions for ``int`` and ``float`` values.
        options: Candidate values used to generate completions for ``str``.
        include_none: When true (and the type is not ``dict``), the literal
            ``"None"`` is offered as a completion.
        private: Stored as given; not otherwise used by this class.
        completions: Completion mapping built by ``build_completions``.
    """

    def __init__(
        self,
        value_type,
        min_value=None,
        max_value=None,
        options=None,
        include_none=False,
        private=False,
    ):
        self.value_type = value_type
        self.min_value = min_value
        self.max_value = max_value
        # A falsy ``options`` (None, empty list, ...) is normalised to an empty dict.
        self.options = options or {}
        self.include_none = include_none
        self.private = private
        self.completions = {}
        self.build_completions()

    def build_completions(self):
        """Populate ``self.completions`` according to ``self.value_type``.

        Raises:
            ValueError: If ``value_type`` is not str, int, float, bool or dict.
        """
        has_bounds = self.min_value is not None and self.max_value is not None
        if self.value_type is bool:
            self.completions = {
                "true": None,
                "false": None,
            }
        elif self.value_type is int:
            # Numeric completions are only generated when both bounds are known.
            if has_bounds:
                self.completions = util.list_to_completion_hash(
                    range(self.min_value, self.max_value + 1)
                )
        elif self.value_type is float:
            if has_bounds:
                self.completions = util.float_range_to_completions(
                    self.min_value, self.max_value
                )
        elif self.value_type is str:
            self.completions = util.list_to_completion_hash(self.options)
        elif self.value_type is dict:
            # Dict values have free-form keys, so there is nothing to complete.
            pass
        else:
            raise ValueError(
                "Invalid value type provided. Must be one of the following: str, int, float, bool"
            )
        if self.value_type is not dict and self.include_none:
            self.completions["None"] = None

    def cast(self, value):
        """Convert ``value`` to ``self.value_type``.

        ``None`` and the string ``"None"`` are always accepted and map to
        ``None``. Failures are reported through the return value rather than
        by raising, including for inputs of an unexpected Python type.

        Returns:
            tuple: ``(success, cast_value, error_message)`` where
            ``error_message`` is ``None`` on success and ``cast_value`` is
            ``None`` on failure.
        """
        if value is None or value == "None":
            return True, None, None
        if self.value_type is bool:
            if isinstance(value, bool):
                return True, value, None
            # Normalise via str() so non-string input (e.g. the int 1) is
            # handled instead of failing on a missing .lower() method.
            normalized = str(value).lower()
            if normalized in ("true", "t", "1"):
                return True, True, None
            if normalized in ("false", "f", "0"):
                return True, False, None
            return (
                False,
                None,
                "Invalid value provided. Must be one of the following: true, false",
            )
        if self.value_type is int:
            try:
                return True, int(value), None
            except (TypeError, ValueError):
                # TypeError covers non-scalar input such as lists or dicts.
                return False, None, "Invalid value provided. Must be an integer"
        if self.value_type is float:
            try:
                return True, float(value), None
            except (TypeError, ValueError):
                return False, None, "Invalid value provided. Must be a float"
        if self.value_type is str:
            return True, str(value), None
        return (
            False,
            None,
            "Invalid value type provided. Must be one of the following: str, int, float, bool",
        )
[docs]
class ProviderBase(Plugin):
    """Shared behaviour for LLM provider plugins.

    This class reads ``capabilities``, ``default_model``, ``llm_factory()``
    and ``customization_config()`` but does not define them; concrete
    providers must supply them. It also uses ``name``, ``config``, ``log``,
    ``cache_manager``, ``plugin_cache_filename``, ``write_plugin_cache_file``
    and ``transform_tools``, none of which are defined here (presumably
    inherited from ``Plugin`` -- confirm).

    Most customization helpers return a ``(success, value, user_message)``
    tuple instead of raising.
    """

    @property
    def display_name(self):
        """Plugin name with the first ``len(constants.PROVIDER_PREFIX)`` characters removed."""
        return self.name[len(constants.PROVIDER_PREFIX) :]

    @property
    def plugin_type(self):
        """Plugin type identifier; always ``"provider"``."""
        return "provider"

    @property
    def model_property_name(self):
        """Customization key under which the model name is stored."""
        return "model_name"

    @property
    def available_models(self):
        """Keys view of the provider's model mapping."""
        return self.get_capability("models", {}).keys()

    def setup(self):
        """Load the model list and reset customizations to the provider defaults."""
        self.load_models()
        self.set_customizations(self.default_customizations())

    def default_config(self):
        """Return the default plugin configuration (empty for the base class)."""
        return {}

    def load_models(self):
        """Populate ``self.models``.

        Resolution order:
          1. The ``plugins.<name>.models`` config value, when truthy.
          2. If the provider defines ``fetch_models``: the cached model list,
             or a fresh fetch that is then written to the plugin cache.
          3. A deep copy of ``static_models`` (also used when step 2 raises).
        """
        self.models = self.config.get(f"plugins.{self.name}.models")
        if self.models:
            return
        if not hasattr(self, "fetch_models"):
            self.models = copy.deepcopy(self.static_models)
            return
        try:
            success, data, _user_message = self.cache_manager.cache_get(
                self.plugin_cache_filename
            )
            if success:
                self.models = data["models"]
            else:
                message = f"Fetching data for {self.name}\nData can be refreshed by running '{constants.COMMAND_LEADER}plugin reload {self.name}'"
                self.log.info(message)
                util.print_status_message(True, message)
                self.models = self.fetch_models()
                self.write_plugin_cache_file({"models": self.models})
        except Exception as e:
            # Deliberate best-effort: any cache or fetch failure is reported
            # and the provider falls back to its static model list.
            message = f"Could not fetch data for {self.name}: {e}"
            self.log.error(message)
            util.print_status_message(False, message)
            self.models = copy.deepcopy(self.static_models)

    @property
    def static_models(self):
        """Fallback model mapping used when no other source is available."""
        return {}

    def default_customizations(self, defaults=None):
        """Build the default customization dict for this provider.

        The LLM class from ``llm_factory()`` is instantiated with ``defaults``
        and its ``dict()`` output is filtered down to the keys present in
        ``customization_config()``. The model key is then set to
        ``default_model`` (when truthy) and ``_type`` is copied over.
        """
        defaults = defaults or {}
        llm_class = self.llm_factory()
        llm_defaults = llm_class(**defaults).dict()
        custom_config = self.customization_config()
        result = {k: v for k, v in llm_defaults.items() if k in custom_config}
        if self.default_model:
            result[self.model_property_name] = self.default_model
        result["_type"] = llm_defaults["_type"]
        return result

    def cast_dict_value(self, value):
        """Best-guess cast for free-form dict values.

        Strings are converted to ``int`` if possible, else ``float`` if
        possible, else returned unchanged. Non-strings pass through untouched.
        """
        if not isinstance(value, str):
            return value
        for caster in (int, float):
            try:
                return caster(value)
            except ValueError:
                continue
        return value

    def calculate_customization_value(self, orig_keys, new_value):
        """Validate and cast ``new_value`` for the config path ``orig_keys``.

        Args:
            orig_keys: Dotted string or sequence of path segments. The
                sequence is not modified.
            new_value: Raw value to validate.

        Returns:
            tuple: ``(success, cast_value, user_message)``.
        """
        if isinstance(orig_keys, str):
            orig_keys = orig_keys.split(".")
        config = self.customization_config()
        # Pre-bind so the failure message is well-defined for an empty path.
        key = ""
        last_index = len(orig_keys) - 1
        for index, key in enumerate(orig_keys):
            if key not in config:
                break
            config = config[key]
            if config is None:
                return True, new_value, "Passing through value."
            if config is dict:
                # The config declares a free-form dict at this point.
                return True, self.cast_dict_value(new_value), "Found dict key."
            # Special case: some LLM classes return an unset key for a nested config by default
            # so pass through this unset key.
            if index == last_index and isinstance(config, dict):
                return True, None, f"Returning None for unconfigured dict '{key}'."
            if isinstance(config, PresetValue):
                success, cast_value, user_message = config.cast(new_value)
                if success:
                    return True, cast_value, "Found preset value."
                return False, None, user_message
        return False, None, f"Invalid key {key}."

    def get_customization_value(self, keys):
        """Look up the current customization value at ``keys``.

        Args:
            keys: Dotted string or sequence of path segments. The sequence is
                not modified.

        Returns:
            tuple: ``(success, value, user_message)``. The lookup fails as
            soon as a segment is missing. A path that ends on a nested dict
            yields ``None`` as the value.
        """
        if isinstance(keys, str):
            keys = keys.split(".")
        current = self.customizations
        # Pre-bind so the failure message is well-defined for an empty path.
        key = ""
        last_index = len(keys) - 1
        for index, key in enumerate(keys):
            if key not in current:
                # Stop at the first unknown segment rather than trying the
                # remaining segments against the same level.
                break
            current = current[key]
            if not isinstance(current, dict):
                return True, current, f"Found key {key}"
            # Special case: some LLM classes return an unset key for a nested config by default
            # so pass through this unset key.
            if index == last_index:
                return True, None, f"Returning None for unconfigured dict '{key}'."
        return False, None, f"Invalid key {key}."

    def set_customization_value(self, keys, new_value):
        """Validate ``new_value`` and store it at ``keys``.

        A dict ``new_value`` is applied entry by entry (recursively), stopping
        at the first failure. An empty dict is a successful no-op.

        Returns:
            tuple: ``(success, value, user_message)``; on success ``value`` is
            the full customizations dict.
        """
        if isinstance(keys, str):
            keys = keys.split(".")
        if isinstance(new_value, dict):
            if not new_value:
                # Nothing to apply; return a defined result instead of
                # reading loop variables that were never bound.
                return True, self.customizations, f"No values to set for {'.'.join(keys)}"
            for sub_key, sub_value in new_value.items():
                result = self.set_customization_value(keys + [sub_key], sub_value)
                if not result[0]:
                    return result
            return result
        success, cast_value, user_message = self.calculate_customization_value(keys, new_value)
        if not success:
            return success, cast_value, user_message
        self.set_value(keys, cast_value)
        return success, self.customizations, f"Set {'.'.join(keys)} to {cast_value}"

    def set_value(self, keys, value):
        """Store ``value`` at the nested path ``keys`` in ``self.customizations``.

        Intermediate levels that are missing or ``None`` are replaced with
        empty dicts.
        """
        target = self.customizations
        for key in keys[:-1]:
            if target.get(key) is None:
                target[key] = {}
            target = target[key]
        target[keys[-1]] = value

    def get_customizations(self, customizations=None):
        """Return a copy of the customizations without the ``_type`` key.

        Args:
            customizations: Dict to filter; defaults to ``self.customizations``.
        """
        source = self.customizations if customizations is None else customizations
        # _type key exists in the Langchain output of dict(llm), and must be filtered out.
        return {k: v for k, v in source.items() if k != "_type"}

    def set_customizations(self, customizations):
        """Replace the current customizations wholesale."""
        self.customizations = customizations

    def customizations_to_completions(self):
        """Flatten ``customization_config()`` into a dotted-key completion mapping."""

        def collect(completions, items, prefix=None):
            prefix = prefix or []
            for key, value in items.items():
                path = prefix + [key]
                full_key = ".".join(path)
                if value is None:
                    continue
                if value is dict:
                    completions[full_key] = {}
                elif isinstance(value, dict):
                    # Register every child first so children whose config is
                    # None (skipped by the recursion) still get an entry.
                    for child_key, child in value.items():
                        dict_key = "%s.%s" % (full_key, child_key)
                        if isinstance(child, PresetValue):
                            completions[dict_key] = child.completions
                        else:
                            completions[dict_key] = None
                    collect(completions, value, path)
                elif isinstance(value, PresetValue):
                    completions[full_key] = value.completions
                else:
                    completions[full_key] = value
            return completions

        return collect({}, self.customization_config())

    def get_capability(self, capability, default=False):
        """Return a provider capability.

        ``"models"`` always resolves to ``self.models``; anything else is read
        from ``self.capabilities``, falling back to ``default``.
        """
        if capability == "models":
            return self.models
        return self.capabilities[capability] if capability in self.capabilities else default

    def get_model(self):
        """Return the currently configured model name, or ``None`` if the lookup fails."""
        success, model_name, _user_message = self.get_customization_value(
            self.model_property_name
        )
        if success:
            return model_name
        return None

    def invalid_model_message(self, model_name):
        """Build the user-facing message for an unknown model name."""
        provider_name = getattr(self, "name", self.__class__.__name__)
        # Use the configured command leader, matching the reload hint that
        # load_models() prints for the same command.
        return (
            f"Invalid model {model_name!r} for {provider_name!r}; model is not present "
            f"in the provider model list/cache. Refresh the provider cache with "
            f"'{constants.COMMAND_LEADER}plugin reload {provider_name}' or update the preset."
        )

    def validate_model(self, model_name):
        """Check ``model_name`` against the provider's model list.

        Validation is skipped (always succeeds) when the ``validate_models``
        capability is falsy.

        Returns:
            tuple: ``(success, model_name_or_None, user_message)``.
        """
        models = self.get_capability("models", {})
        validate_models = self.get_capability("validate_models", True)
        if not validate_models or model_name in models:
            return True, model_name, f"Model {model_name!r} is valid"
        return False, None, self.invalid_model_message(model_name)

    def set_model(self, model_name):
        """Validate ``model_name`` and, if valid, store it as the active model."""
        success, _model_name, user_message = self.validate_model(model_name)
        if success:
            return self.set_customization_value(self.model_property_name, model_name)
        return success, None, user_message

    def make_llm(self, customizations=None, tools=None, tool_choice=None, use_defaults=False):
        """Instantiate the provider's LLM.

        Args:
            customizations: Overrides applied on top of the base customizations.
            tools: Optional tools; when given they are passed through
                ``transform_tools`` and bound with ``bind_tools``.
            tool_choice: Optional ``tool_choice`` forwarded to ``bind_tools``.
            use_defaults: Start from ``default_customizations()`` instead of
                the current customizations.

        Returns:
            The LLM instance, with tools bound when requested and supported.
        """
        customizations = customizations or {}
        final_customizations = (
            self.get_customizations(self.default_customizations())
            if use_defaults
            else self.get_customizations()
        )
        final_customizations.update(customizations)
        # Private keys must not reach the LLM constructor.
        for key in constants.PROVIDER_PRIVATE_CUSTOMIZATION_KEYS:
            final_customizations.pop(key, None)
        llm_class = self.llm_factory()
        # Optional hook: providers may define llm_pre_init to adjust the kwargs.
        llm_pre_init_method = getattr(self, "llm_pre_init", None)
        if llm_pre_init_method:
            final_customizations = llm_pre_init_method(final_customizations)
        llm = llm_class(**final_customizations)
        if tools:
            self.log.debug(f"Provider {self.display_name} called with tools")
            kwargs = {
                "tools": self.transform_tools(tools),
            }
            if tool_choice:
                kwargs["tool_choice"] = tool_choice
            try:
                llm = llm.bind_tools(**kwargs)
            except NotImplementedError:
                # Fall back to the unbound LLM when tools are unsupported.
                self.log.warning(f"Provider {self.display_name} does not support tools")
        return llm

    def prepare_messages_method(self):
        """Return the callable used to shape messages; last-message mode by default."""
        return self.prepare_messages_for_llm_last_message

    def prepare_messages_for_llm_last_message(self, messages):
        """Return only the ``content`` of the final message."""
        return messages[-1]["content"]

    def prepare_messages_for_llm_stuff_messages(self, messages):
        """Join the ``content`` of every message with blank lines."""
        return "\n\n".join(m["content"] for m in messages)

    def prepare_messages_for_llm_chat(self, messages):
        """Convert every message dict into a LangChain message object."""
        return [self.convert_dict_to_message(m) for m in messages]

    def prepare_messages_for_llm(self, messages):
        """Shape ``messages`` using the strategy from ``prepare_messages_method()``."""
        method = self.prepare_messages_method()
        return method(messages)

    def max_submission_tokens(self):
        """Return ``max_tokens`` for the active model, or the default constant.

        The default is ``constants.OPEN_AI_DEFAULT_MAX_SUBMISSION_TOKENS``,
        used when the model is unknown or declares no ``max_tokens``.
        """
        models = self.get_capability("models", {})
        model_name = self.get_model()
        if model_name and model_name in models and "max_tokens" in models[model_name]:
            return models[model_name]["max_tokens"]
        return constants.OPEN_AI_DEFAULT_MAX_SUBMISSION_TOKENS

    def convert_ai_dict_to_message(self, message: Mapping[str, Any]) -> AIMessage:
        """Convert an LWE message dictionary to a LangChain AIMessage.

        This default implementation supports a format suitable for OpenAI.
        Other provider plugins may need to override this method depending
        on the structure of their AI messages.
        """
        content = message.get("content", "")
        kwargs = {}
        tool_calls = message.get("tool_calls", None)
        if tool_calls:
            kwargs["tool_calls"] = tool_calls
        # NOTE: Remove this if Langchain integrations ever consistently
        # support using AIMessage.tool_calls property.
        additional_kwargs = message.get("additional_kwargs", None)
        if additional_kwargs:
            kwargs["additional_kwargs"] = additional_kwargs
        return AIMessage(content=content, **kwargs)

    def convert_dict_to_message(self, message: Mapping[str, Any]) -> BaseMessage:
        """Convert an LWE message dictionary to a LangChain message.

        Raises:
            ValueError: If the ``role`` is not user, assistant, system or tool.
        """
        role = message.get("role")
        content = message.get("content", "")
        if role == "user":
            return HumanMessage(content=content)
        if role == "assistant":
            return self.convert_ai_dict_to_message(message)
        if role == "system":
            return SystemMessage(content=content)
        if role == "tool":
            return ToolMessage(
                content=content,
                tool_call_id=message.get("tool_call_id"),
                name=message.get("name", None),
            )
        raise ValueError(f"Unknown role: {role}")

    def prepare_file_for_llm(self, file: dict):
        """Return ``file`` unchanged; providers may override to transform it."""
        return file
[docs]
class Provider(ProviderBase):
    """Declare the members every concrete provider plugin must implement.

    Each member is marked with ``abstractmethod``. Whether instantiation is
    actually blocked depends on the metaclass of the base classes, which is
    not visible here.
    """

    @property
    @abstractmethod
    def capabilities(self):
        """Capabilities of the provider, looked up by ``get_capability``."""

    @property
    @abstractmethod
    def default_model(self):
        """Model name used when building the default customizations."""

    @abstractmethod
    def llm_factory(self):
        """Return the LLM class that ``make_llm`` instantiates."""

    @abstractmethod
    def customization_config(self):
        """Return the configuration describing the allowed customization keys."""