# Source code for lwe.plugins.provider_fake_llm

from typing import Any, Iterator, List, Optional, Union

from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import ChatGenerationChunk

# TODO: Re-enable if https://github.com/langchain-ai/langchain/pull/10200 lands.
# from langchain.chat_models.fake import FakeMessagesListChatModel

from lwe.core.provider import Provider, PresetValue
from lwe.core import constants

# TODO: Remove these definitions if https://github.com/langchain-ai/langchain/pull/10200 lands.
import asyncio
import time
from typing import cast, AsyncIterator, Dict
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessageChunk
from langchain_core.outputs import ChatGeneration, ChatResult

# TODO: Remove these definitions if https://github.com/langchain-ai/langchain/pull/10200 lands.

DEFAULT_RESPONSE_MESSAGE = "test response"


class FakeMessagesListChatModel(BaseChatModel):
    """Fake chat model for testing that replays a fixed list of responses.

    Every call to ``_call``, ``_stream`` or ``_astream`` takes the entry of
    ``responses`` at index ``i`` and then advances ``i``, wrapping back to
    the first entry after the last one has been served.

    An entry is either a single response (a string or a message) or a list
    of them. ``_call`` turns a list entry into one message per item, while
    the streaming methods iterate over the entry and emit one chunk per item.
    """

    # Canned responses, served in order.
    responses: Union[
        List[Union[BaseMessage, BaseMessageChunk, str]],
        List[List[Union[BaseMessage, BaseMessageChunk, str]]],
    ]
    # Optional delay, in seconds, applied before each streamed chunk.
    sleep: Optional[float] = None
    # Index of the next entry of ``responses`` to serve.
    i: int = 0

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this model type."""
        return "fake-messages-list-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters identifying this model instance."""
        return {"responses": self.responses}

    def _next_response(self) -> Any:
        """Return the current response entry and advance the rotation index.

        Raises:
            IndexError: If ``responses`` is empty or ``i`` is out of range.
        """
        response = self.responses[self.i]
        # Wrap around to the first entry once the last one has been served.
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        return response

    def _to_generation_chunk(self, item: Any) -> ChatGenerationChunk:
        """Wrap one streamed item in a ``ChatGenerationChunk``.

        Raises:
            TypeError: If ``item`` is neither an ``AIMessageChunk`` nor a
                string.
        """
        if isinstance(item, AIMessageChunk):
            message = item
        elif isinstance(item, str):
            message = AIMessageChunk(content=item)
        else:
            raise TypeError(f"Unexpected type for response chunk: {type(item)}")
        return ChatGenerationChunk(message=message)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the next response as a ``ChatResult``.

        A list response produces one generation per message; a single
        response produces exactly one generation.
        """
        response = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        answers = response if isinstance(response, list) else [response]
        return ChatResult(
            generations=[ChatGeneration(message=answer) for answer in answers]
        )

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Union[BaseMessage, List[BaseMessage]]:
        """Rotate through responses.

        Strings are converted to ``AIMessage`` instances; messages are
        returned unchanged.

        Returns:
            A single message, or a list of messages when the current entry
            is a list.

        Raises:
            TypeError: If the entry, or an item of a list entry, is neither
                a string nor a message.
        """
        response = self._next_response()
        if isinstance(response, str):
            return AIMessage(content=response)
        if isinstance(response, BaseMessage):
            return response
        if isinstance(response, list):
            # Build a new list instead of converting items in place:
            # rewriting the configured entry would change what
            # ``_identifying_params`` reports and would make a later
            # ``_stream``/``_astream`` over the same entry fail, since those
            # accept only ``AIMessageChunk`` or ``str`` items.
            converted: List[BaseMessage] = []
            for item in response:
                if isinstance(item, str):
                    converted.append(AIMessage(content=item))
                elif isinstance(item, BaseMessage):
                    converted.append(item)
                else:
                    raise TypeError(f"Unexpected type in response list: {type(item)}")
            return converted
        raise TypeError(f"Unexpected type for response: {type(response)}")

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Rotate through responses.

        Iterates over the current entry and yields one chunk per item,
        sleeping ``sleep`` seconds before each one when configured.

        Raises:
            TypeError: If an item is neither an ``AIMessageChunk`` nor a
                string.
        """
        for item in self._next_response():
            if self.sleep is not None:
                time.sleep(self.sleep)
            yield self._to_generation_chunk(item)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Rotate through responses.

        Asynchronous counterpart of ``_stream``; the optional delay uses
        ``asyncio.sleep``.

        Raises:
            TypeError: If an item is neither an ``AIMessageChunk`` nor a
                string.
        """
        for item in self._next_response():
            if self.sleep is not None:
                await asyncio.sleep(self.sleep)
            yield self._to_generation_chunk(item)
class CustomFakeMessagesListChatModel(FakeMessagesListChatModel):
    """Fake chat model with a model name and a built-in default response.

    When no responses are configured, the first call installs a default
    response built from ``DEFAULT_RESPONSE_MESSAGE`` before delegating to
    the parent class.
    """

    # Name reported for the fake model.
    model_name: str

    def __init__(self, **kwargs):
        """Fill in defaults for missing or empty ``model_name``/``responses``."""
        if not kwargs.get("model_name"):
            kwargs["model_name"] = constants.API_BACKEND_DEFAULT_MODEL
        if not kwargs.get("responses"):
            kwargs["responses"] = []
        super().__init__(**kwargs)

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Union[BaseMessage, List[BaseMessage]]:
        """Return the next response, installing the default one if none is set."""
        if not self.responses:
            self.responses = [
                AIMessage(content=DEFAULT_RESPONSE_MESSAGE),
            ]
        return super()._call(messages, stop, run_manager, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the next response, installing the default one if none is set."""
        if not self.responses:
            self.responses = [
                [
                    AIMessageChunk(content=DEFAULT_RESPONSE_MESSAGE),
                ],
            ]
        return super()._stream(messages, stop, run_manager, **kwargs)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously stream the next response, with the same default.

        Mirrors ``_stream`` so that async streaming with no configured
        responses yields the default response instead of failing on an
        empty list.
        """
        if not self.responses:
            self.responses = [
                [
                    AIMessageChunk(content=DEFAULT_RESPONSE_MESSAGE),
                ],
            ]
        async for chunk in super()._astream(messages, stop, run_manager, **kwargs):
            yield chunk
# Default capabilities for the fake provider: chat is enabled, model
# validation is disabled, and each listed model name maps to its
# "max_tokens" setting.
DEFAULT_CAPABILITIES = {
    "chat": True,
    "validate_models": False,
    "models": {
        model_name: {"max_tokens": max_tokens}
        for model_name, max_tokens in (
            ("gpt-3.5-turbo", 4096),
            ("gpt-3.5-turbo-1106", 16384),
            ("gpt-4", 8192),
            ("gpt-4o", 131072),
        )
    },
}
class ProviderFakeLlm(Provider):
    """
    Fake LLM provider.

    Serves canned responses through ``CustomFakeMessagesListChatModel``
    instead of calling a real model.
    """

    def __init__(self, config=None, **kwargs):
        """Install the default capabilities, then run the base initializer."""
        # Assigned before the base initializer runs, as in the original
        # ordering (presumably the base class reads capabilities — confirm).
        self._capabilities = DEFAULT_CAPABILITIES
        super().__init__(config, **kwargs)

    @property
    def capabilities(self):
        """Capabilities currently assigned to this provider."""
        return self._capabilities

    @capabilities.setter
    def capabilities(self, value):
        """Replace the capabilities assigned to this provider."""
        self._capabilities = value

    @property
    def default_model(self):
        """Name of the model used when none is specified."""
        return constants.API_BACKEND_DEFAULT_MODEL

    def prepare_messages_method(self):
        """Return the callable used to prepare messages for this provider."""
        return self.prepare_messages_for_llm_chat

    def llm_factory(self):
        """Return the chat model class used to build the fake LLM."""
        return CustomFakeMessagesListChatModel

    def customization_config(self):
        """Return the customizable settings: ``responses`` and ``model_name``."""
        model_name_preset = PresetValue(str, options=self.available_models)
        return {
            "responses": None,
            "model_name": model_name_preset,
        }