class KimiAudioTokenizer(TokenizerLike):
    """TikToken tokenizer for Kimi-Audio.

    Wraps a tiktoken encoding loaded from ``tiktoken.model`` (or
    ``tokenizer.model``) and exposes the subset of the Hugging Face tokenizer
    interface used by callers. Special tokens come from the
    ``added_tokens_decoder`` section of a sibling ``tokenizer_config.json``
    (when present) plus a fixed set of Kimi-Audio control tokens.
    """

    # Kimi-Audio control tokens and their IDs. Shared by
    # `_add_kimiaudio_special_tokens` and `encode` so the registered tokens and
    # the tokens allowed during encoding cannot drift apart. Treat as read-only.
    _KIMIA_SPECIAL_TOKENS: dict[str, int] = {
        "<|im_media_begin|>": 151661,
        "<|im_media_end|>": 151663,
        "<|im_kimia_text_blank|>": 151666,
        "<|im_msg_end|>": 151645,
        "<|im_kimia_user_msg_start|>": 151670,
        "<|im_kimia_assistant_msg_start|>": 151671,
    }

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "KimiAudioTokenizer":
        """Load the tokenizer from a vocab file, a local directory, or a Hub repo.

        Args:
            path_or_repo_id: Path to the vocab file itself, a directory that
                contains ``tiktoken.model`` or ``tokenizer.model``, or a
                Hugging Face Hub repo id (used when the path does not exist
                locally).
            *args: Ignored; their presence is logged once at debug level.
            trust_remote_code: Accepted for interface compatibility; unused.
            revision: Hub revision to download from.
            download_dir: Forwarded to ``hf_hub_download`` as ``local_dir``.
            **kwargs: Only ``truncation_side`` is read (default ``"left"``).

        Returns:
            A new ``KimiAudioTokenizer``.

        Raises:
            ValueError: Neither vocab file could be downloaded from the Hub.
            FileNotFoundError: The resolved vocab file does not exist.
        """
        if args:
            logger.debug_once("Ignoring extra positional args for KimiAudioTokenizer.")

        path = Path(path_or_repo_id)
        if path.is_file():
            vocab_file = path
        elif path.is_dir():
            vocab_file = path / "tiktoken.model"
            if not vocab_file.is_file():
                vocab_file = path / "tokenizer.model"
        else:
            vocab_file = cls._download_from_hub(
                str(path_or_repo_id), revision=revision, download_dir=download_dir
            )

        if not vocab_file.is_file():
            raise FileNotFoundError(
                f"Vocab file not found at {vocab_file} "
                "(expected tiktoken.model or tokenizer.model)."
            )
        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
        )

    @staticmethod
    def _download_from_hub(
        repo_id: str, *, revision: str | None, download_dir: str | None
    ) -> Path:
        """Download the vocab file and (best-effort) tokenizer_config.json.

        ``tiktoken.model`` is tried first, then ``tokenizer.model``. The config
        is fetched as well so that `__init__` can find it next to the vocab
        file; failure to fetch it is ignored.

        Raises:
            ValueError: Neither vocab file could be downloaded.
        """
        vocab_file: Path | None = None
        last_exc: Exception | None = None
        for filename in ("tiktoken.model", "tokenizer.model"):
            try:
                vocab_file = Path(
                    hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        revision=revision,
                        local_dir=download_dir,
                    )
                )
                break
            except Exception as exc:  # any Hub failure -> try the next name
                last_exc = exc
        if vocab_file is None:
            raise ValueError(
                f"Could not find tiktoken.model or tokenizer.model in {repo_id}"
            ) from last_exc

        # The config is optional; a missing file or network error is ignored.
        with contextlib.suppress(Exception):
            hf_hub_download(
                repo_id=repo_id,
                filename="tokenizer_config.json",
                revision=revision,
                local_dir=download_dir,
            )
        return vocab_file

    def __init__(
        self,
        *,
        vocab_file: Path,
        name_or_path: str,
        truncation_side: str,
    ) -> None:
        """Build the encoding and the token <-> ID lookup tables.

        Args:
            vocab_file: Path to the tiktoken vocab file.
            name_or_path: Identifier reported as ``self.name_or_path``.
            truncation_side: ``"left"`` keeps the last tokens on truncation;
                any other value keeps the first tokens.
        """
        super().__init__()
        self.name_or_path = name_or_path
        self._truncation_side = truncation_side
        self._vocab_file = vocab_file

        config_special_tokens = self._read_config_special_tokens(vocab_file)
        self._tokenizer, self._special_tokens = _load_tiktoken_encoding(
            vocab_file, config_special_tokens
        )

        # Build token <-> ID mappings from the BPE ranks.
        # NOTE(review): byte tokens that are not valid UTF-8 on their own decode
        # to replacement characters, so several IDs can share one string key;
        # the last one wins in `_token_to_id`.
        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}
        for token_bytes, token_id in self._tokenizer._mergeable_ranks.items():
            token_str = token_bytes.decode("utf-8", errors="replace")
            self._token_to_id[token_str] = token_id
            self._id_to_token[token_id] = token_str

        # Register the loaded special tokens too, so convert_tokens_to_ids /
        # convert_ids_to_tokens / get_vocab know about them. Assumes
        # `_special_tokens` maps token text -> id (its values are used as ids
        # in `decode`).
        for token_str, token_id in self._special_tokens.items():
            self._token_to_id.setdefault(token_str, token_id)
            self._id_to_token.setdefault(token_id, token_str)

        # Must exist before `_add_kimiaudio_special_tokens` writes into it.
        self._added_tokens_decoder: dict[int, Any] = {}
        self._add_kimiaudio_special_tokens()

        # Default special token IDs. `_add_kimiaudio_special_tokens` does not
        # go through the `added_tokens_decoder` setter, so these only change if
        # that setter is called later.
        # NOTE(review): confirm these defaults against the model's
        # tokenizer_config.
        self._bos_token_id = 151643  # Kimi-Audio BOS
        self._eos_token_id = 151644  # Kimi-Audio EOS
        self._pad_token_id = self._eos_token_id
        self._unk_token_id = self._pad_token_id
        self._max_chars_per_token = max(
            (len(tok) for tok in self._token_to_id), default=10
        )

    @staticmethod
    def _read_config_special_tokens(vocab_file: Path) -> dict[str, int]:
        """Return ``{content: id}`` from the sibling tokenizer_config.json.

        Reads the ``added_tokens_decoder`` section and skips entries with
        empty content. Returns an empty dict if the config file is absent.
        """
        config_path = vocab_file.parent / "tokenizer_config.json"
        if not config_path.is_file():
            return {}
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        special_tokens: dict[str, int] = {}
        for token_id_str, token_info in config.get("added_tokens_decoder", {}).items():
            token_id = int(token_id_str)
            content = token_info.get("content", "")
            if content:
                special_tokens[content] = token_id
        return special_tokens

    def _add_kimiaudio_special_tokens(self) -> None:
        """Register the Kimi-Audio control tokens as special added tokens.

        Each token is added to ``added_tokens_decoder``, ``_token_to_id`` and
        ``_id_to_token`` unless an entry already exists there.
        """
        for token_str, token_id in self._KIMIA_SPECIAL_TOKENS.items():
            if token_id not in self._added_tokens_decoder:
                self._added_tokens_decoder[token_id] = AddedToken(
                    token_str, single_word=True, normalized=False, special=True
                )
            self._token_to_id.setdefault(token_str, token_id)
            self._id_to_token.setdefault(token_id, token_str)

    @staticmethod
    def _token_content(token: Any) -> str:
        """Return the text of an added-token entry (``AddedToken`` or ``str``)."""
        content = getattr(token, "content", None)
        return content if isinstance(content, str) else str(token)

    def _special_ids(self) -> set[int]:
        """IDs of every special token: loaded specials plus added tokens."""
        return set(self._special_tokens.values()) | set(self._added_tokens_decoder)

    def num_special_tokens_to_add(self) -> int:
        """No BOS/EOS is ever added by `encode`, so this is always 0."""
        return 0

    @property
    def all_special_tokens(self) -> list[str]:
        """Text of every entry in ``added_tokens_decoder``."""
        return [
            self._token_content(token) for token in self._added_tokens_decoder.values()
        ]

    @property
    def all_special_ids(self) -> list[int]:
        """IDs of every entry in ``added_tokens_decoder``."""
        return list(self._added_tokens_decoder.keys())

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.n_vocab

    @property
    def max_token_id(self) -> int:
        return self._tokenizer.n_vocab - 1

    @property
    def max_chars_per_token(self) -> int:
        return self._max_chars_per_token

    @property
    def truncation_side(self) -> str:
        return self._truncation_side

    @property
    def added_tokens_decoder(self) -> dict[int, Any]:
        return self._added_tokens_decoder

    @added_tokens_decoder.setter
    def added_tokens_decoder(self, value: dict[int, Any]) -> None:
        """Replace the added-tokens table and refresh the BOS/EOS IDs.

        ``<|im_kimia_user_msg_start|>`` becomes BOS. ``<|im_msg_end|>`` or
        ``<|im_end|>`` becomes EOS. The pad/unk IDs are not touched.
        """
        self._added_tokens_decoder = value
        for token_id, token in value.items():
            content = self._token_content(token)
            if content == "<|im_kimia_user_msg_start|>":
                self._bos_token_id = token_id
            elif content in ("<|im_msg_end|>", "<|im_end|>"):
                self._eos_token_id = token_id

    def get_vocab(self) -> dict[str, int]:
        """Return a copy of the token-string -> ID mapping."""
        return dict(self._token_to_id)

    def __len__(self) -> int:
        """Return vocab size for compatibility with HF tokenizer interface."""
        return self._tokenizer.n_vocab

    def get_added_vocab(self) -> dict[str, int]:
        """Map each added token's text to its ID."""
        return {
            self._token_content(token): token_id
            for token_id, token in self._added_tokens_decoder.items()
        }

    def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
        """Trim ``tokens`` to at most ``max_length`` on ``truncation_side``.

        Negative ``max_length`` is treated as 0.
        """
        if max_length is None or len(tokens) <= max_length:
            return tokens
        keep = max(max_length, 0)
        if self.truncation_side == "left":
            # Slice from a start index: `tokens[-0:]` would keep everything.
            return tokens[len(tokens) - keep :]
        return tokens[:keep]

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> list[int]:
        """Encode ``text`` to token IDs.

        The Kimi-Audio control tokens are recognised as special tokens.
        Other special tokens are subject to tiktoken's default
        ``disallowed_special`` handling. ``add_special_tokens`` is ignored
        because no BOS/EOS is added.
        """
        del add_special_tokens
        # tiktoken needs a real set (it computes `special_set - allowed_special`).
        tokens = self._tokenizer.encode(
            text, allowed_special=set(self._KIMIA_SPECIAL_TOKENS)
        )
        if truncation:
            tokens = self._maybe_truncate(tokens, max_length)
        return tokens

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode token IDs to text, optionally skipping special tokens.

        With ``skip_special_tokens`` both the loaded special tokens and every
        ``added_tokens_decoder`` entry are dropped, matching
        `convert_ids_to_tokens`.
        """
        if isinstance(ids, int):
            ids = [ids]
        if skip_special_tokens:
            special_ids = self._special_ids()
            ids = [token_id for token_id in ids if token_id not in special_ids]
        return self._tokenizer.decode(ids)

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        """Look up token IDs; unknown tokens map to the unk ID."""
        if isinstance(tokens, str):
            return self._token_to_id.get(tokens, self._unk_token_id)
        return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Look up token strings; unknown IDs map to ``"<|unk|>"``."""
        skipped = self._special_ids() if skip_special_tokens else set()
        return [
            self._id_to_token.get(token_id, "<|unk|>")
            for token_id in ids
            if token_id not in skipped
        ]

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token strings by round-tripping them through their IDs."""
        token_ids = self.convert_tokens_to_ids(tokens)
        return self.decode(token_ids, skip_special_tokens=False)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encode one text or a batch into ``input_ids`` / ``attention_mask``.

        Raises:
            NotImplementedError: ``text_pair`` was given.
        """
        if text_pair is not None:
            raise NotImplementedError(
                "text_pair is not supported for KimiAudioTokenizer."
            )
        if isinstance(text, list):
            input_ids_batch: list[list[int]] = [
                self.encode(
                    item,
                    truncation=truncation,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )
                for item in text
            ]
            attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
            return BatchEncoding(
                {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
            )
        input_ids = self.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def get_chat_template(
        self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
    ) -> str | None:
        """Return ``chat_template`` unchanged; there is no built-in template."""
        del tools
        return chat_template

    def apply_chat_template(
        self,
        messages: list[ChatCompletionMessageParam] | None = None,
        tools: list[dict[str, Any]] | None = None,
        chat_template: str | None = None,
        tokenize: bool = False,
        **kwargs,
    ) -> str | list[int]:
        """Render a conversation with ``chat_template`` and optionally tokenize it.

        The conversation may be passed as ``messages`` (protocol name) or as
        a ``conversation`` keyword (caller name). Remaining ``kwargs`` are
        forwarded to the template renderer.

        Raises:
            ValueError: No conversation or no chat template was provided.
        """
        # Pop `conversation` so it is not also forwarded as a template variable.
        conversation_kwarg = kwargs.pop("conversation", None)
        conversation = messages if messages is not None else conversation_kwarg
        if conversation is None:
            raise ValueError("Either 'messages' or 'conversation' must be provided.")
        template = self.get_chat_template(chat_template, tools=tools)
        if template is None:
            raise ValueError(
                "No chat template available. Provide `chat_template` explicitly."
            )

        # render_jinja_template renders a *batch* of conversations and returns
        # ([prompts], [generation_indices]). Wrap a single conversation so it
        # is not iterated message-by-message; pass already-batched input as is.
        is_batched = (
            isinstance(conversation, (list, tuple))
            and len(conversation) > 0
            and (
                isinstance(conversation[0], (list, tuple))
                or hasattr(conversation[0], "messages")
            )
        )
        conversations = conversation if is_batched else [conversation]
        rendered, _ = hf_chat_utils.render_jinja_template(
            conversations,
            chat_template=template,
            tools=tools,
            **kwargs,
        )
        # Only the first conversation's prompt is returned.
        prompt = rendered[0] if rendered else ""
        if tokenize:
            return self.encode(prompt, add_special_tokens=False)
        return prompt