# Source code for lwe.core.token_manager
import json
import tiktoken
from lwe.core.config import Config
from lwe.core.logger import Logger
from lwe.core import util
# [docs]
class TokenManager:
    """
    Manage model tokens.

    Resolves the token encoding for the configured model and counts the
    tokens consumed by a list of messages, delegating the count to the
    provider when the provider supplies its own counter.
    """

    def __init__(self, config, provider, model_name, tool_cache):
        """
        Store the collaborators needed for token counting.

        :param config: Configuration object; a new ``Config()`` is created
            when a falsy value is passed
        :param provider: Provider object; queried for the ``validate_models``
            capability, ``available_models``, and an optional
            ``get_num_tokens_from_messages()`` method
        :param model_name: Name of the model whose tokens are counted
        :type model_name: str
        :param tool_cache: Tool cache; supplies ``add_message_tools()``,
            ``tools`` and ``tool_manager``
        """
        self.config = config or Config()
        self.log = Logger(self.__class__.__name__, self.config)
        self.provider = provider
        self.model_name = model_name
        self.tool_cache = tool_cache

    def get_token_encoding(self):
        """
        Get token encoding for a model.

        :raises NotImplementedError: If unsupported model
        :raises Exception: If error getting encoding
        :returns: Encoding object
        :rtype: Encoding
        """
        # Model validation is on unless the provider explicitly disables it.
        validate_models = self.provider.get_capability("validate_models", True)
        if validate_models and self.model_name not in self.provider.available_models:
            raise NotImplementedError(f"Unsupported model: {self.model_name}")
        try:
            return tiktoken.encoding_for_model(self.model_name)
        except KeyError:
            # No encoding is registered for this model name: fall back to
            # the "cl100k_base" encoding.
            return tiktoken.get_encoding("cl100k_base")
        except Exception as err:
            # Re-raised as a plain Exception (chained to the cause) so the
            # exception type seen by existing callers is unchanged.
            raise Exception(
                f"Unable to get token encoding for model {self.model_name}: {err}"
            ) from err

    def get_num_tokens_from_messages(self, messages, encoding=None):
        """
        Get number of tokens for a list of messages.

        If a provider does not have a get_num_tokens_from_messages() method,
        default_get_num_tokens_from_messages() will be used.

        :param messages: List of messages
        :type messages: list
        :param encoding: Encoding to use, defaults to None to auto-detect
        :type encoding: Encoding, optional
        :returns: Number of tokens
        :rtype: int
        """
        token_counter = getattr(self.provider, "get_num_tokens_from_messages", None)
        if token_counter:
            return token_counter(messages, encoding)
        return self.default_get_num_tokens_from_messages(messages, encoding)

    def default_get_num_tokens_from_messages(self, messages, encoding=None):
        """
        Get number of tokens for a list of messages.

        The default implementation uses tiktoken, which is the OpenAI implementation.

        :param messages: List of messages
        :type messages: list
        :param encoding: Encoding to use, defaults to None to auto-detect
        :type encoding: Encoding, optional
        :returns: Number of tokens
        :rtype: int
        """
        if not encoding:
            encoding = self.get_token_encoding()
        num_tokens = 0
        messages = self.tool_cache.add_message_tools(messages)
        messages = util.transform_messages_to_chat_messages(messages)
        for message in messages:
            num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            for key, value in message.items():
                # Structured values are counted in their serialized JSON form.
                if isinstance(value, (dict, list)):
                    value = json.dumps(value, indent=2)
                if value:
                    num_tokens += len(encoding.encode(str(value)))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens -= 1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        if self.tool_cache.tools:
            # Tool configurations are counted as one serialized JSON string.
            tools = [
                self.tool_cache.tool_manager.get_tool_config(tool_name)
                for tool_name in self.tool_cache.tools
            ]
            tools_string = json.dumps(tools, indent=2)
            num_tokens += len(encoding.encode(tools_string))
        return num_tokens