FEAT: support guided decoding for vllm async engine #2391

Merged (5 commits) on Nov 28, 2024
24 changes: 22 additions & 2 deletions xinference/_compat.py
@@ -60,6 +60,10 @@
    ChatCompletionStreamOptionsParam,
)
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from openai.types.shared_params.response_format_json_object import (
    ResponseFormatJSONObject,
)
from openai.types.shared_params.response_format_text import ResponseFormatText

OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
    ChatCompletionStreamOptionsParam
@@ -70,6 +74,23 @@
)


class JSONSchema(BaseModel):
    name: str
    description: Optional[str] = None
    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
    strict: Optional[bool] = None


class ResponseFormatJSONSchema(BaseModel):
    json_schema: JSONSchema
    type: Literal["json_schema"]


ResponseFormat = Union[
    ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema
]


class CreateChatCompletionOpenAI(BaseModel):
    """
    Comes from source code: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py
@@ -84,8 +105,7 @@ class CreateChatCompletionOpenAI(BaseModel):
    n: Optional[int]
    parallel_tool_calls: Optional[bool]
    presence_penalty: Optional[float]
    # we do not support this
    # response_format: ResponseFormat
    response_format: Optional[ResponseFormat]
    seed: Optional[int]
    service_tier: Optional[Literal["auto", "default"]]
    stop: Union[Optional[str], List[str]]
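For context, a minimal sketch of the request-body shapes the new `ResponseFormat` union is meant to validate; the concrete schema name and fields below are illustrative, not taken from the PR:

```python
# Illustrative payloads only; the schema contents are made up.
# {"type": "json_object"} maps to ResponseFormatJSONObject,
# {"type": "json_schema", ...} maps to ResponseFormatJSONSchema.
json_object_format = {"type": "json_object"}

json_schema_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "weather_report",             # JSONSchema.name
        "description": "Structured weather",  # JSONSchema.description (optional)
        "schema": {                           # stored as JSONSchema.schema_ via the "schema" alias
            "type": "object",
            "properties": {"city": {"type": "string"}, "temp_c": {"type": "number"}},
            "required": ["city", "temp_c"],
        },
        "strict": True,                       # JSONSchema.strict (optional)
    },
}
```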
28 changes: 28 additions & 0 deletions xinference/api/restful_api.py
@@ -1219,6 +1219,9 @@ async def create_completion(self, request: Request) -> Response:
        raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
        kwargs = body.dict(exclude_unset=True, exclude=exclude)

        # guided_decoding params
        kwargs.update(self.extract_guided_params(raw_body=raw_body))

        # TODO: Decide if this default value override is necessary #1061
        if body.max_tokens is None:
            kwargs["max_tokens"] = max_tokens_field.default
@@ -1916,9 +1919,13 @@ async def create_chat_completion(self, request: Request) -> Response:
            "logit_bias_type",
            "user",
        }

        raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
        kwargs = body.dict(exclude_unset=True, exclude=exclude)

        # guided_decoding params
        kwargs.update(self.extract_guided_params(raw_body=raw_body))

        # TODO: Decide if this default value override is necessary #1061
        if body.max_tokens is None:
            kwargs["max_tokens"] = max_tokens_field.default
@@ -2279,6 +2286,27 @@ async def abort_cluster(self) -> JSONResponse:
            logger.error(e, exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

    @staticmethod
    def extract_guided_params(raw_body: dict) -> dict:
        kwargs = {}
        if raw_body.get("guided_json") is not None:
            kwargs["guided_json"] = raw_body.get("guided_json")
        if raw_body.get("guided_regex") is not None:
            kwargs["guided_regex"] = raw_body.get("guided_regex")
        if raw_body.get("guided_choice") is not None:
            kwargs["guided_choice"] = raw_body.get("guided_choice")
        if raw_body.get("guided_grammar") is not None:
            kwargs["guided_grammar"] = raw_body.get("guided_grammar")
        if raw_body.get("guided_json_object") is not None:
            kwargs["guided_json_object"] = raw_body.get("guided_json_object")
        if raw_body.get("guided_decoding_backend") is not None:
            kwargs["guided_decoding_backend"] = raw_body.get("guided_decoding_backend")
        if raw_body.get("guided_whitespace_pattern") is not None:
            kwargs["guided_whitespace_pattern"] = raw_body.get(
                "guided_whitespace_pattern"
            )
        return kwargs


def run(
    supervisor_address: str,
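A hedged sketch of how `extract_guided_params` behaves: any `guided_*` keys present in the raw request body are forwarded to the model as extra kwargs. The enclosing class name `RESTfulAPI` is assumed from this module, and the request body is illustrative:

```python
from xinference.api.restful_api import RESTfulAPI  # class name assumed from this module

# Illustrative raw body of a /v1/chat/completions request.
raw_body = {
    "model": "my-model",  # placeholder model uid
    "messages": [{"role": "user", "content": "Pick a color."}],
    "guided_choice": ["red", "green", "blue"],
    "guided_decoding_backend": "outlines",
    "temperature": 0.0,
}

extra = RESTfulAPI.extract_guided_params(raw_body=raw_body)
# Only the guided_* keys that are actually present survive:
# {"guided_choice": ["red", "green", "blue"], "guided_decoding_backend": "outlines"}
```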
85 changes: 83 additions & 2 deletions xinference/model/llm/vllm/core.py
@@ -69,6 +69,7 @@ class VLLMModelConfig(TypedDict, total=False):
    quantization: Optional[str]
    max_model_len: Optional[int]
    limit_mm_per_prompt: Optional[Dict[str, int]]
    guided_decoding_backend: Optional[str]


class VLLMGenerateConfig(TypedDict, total=False):
@@ -85,6 +86,14 @@ class VLLMGenerateConfig(TypedDict, total=False):
    stop: Optional[Union[str, List[str]]]
    stream: bool  # non-sampling param, should not be passed to the engine.
    stream_options: Optional[Union[dict, None]]
    response_format: Optional[dict]
    guided_json: Optional[Union[str, dict]]
    guided_regex: Optional[str]
    guided_choice: Optional[List[str]]
    guided_grammar: Optional[str]
    guided_json_object: Optional[bool]
    guided_decoding_backend: Optional[str]
    guided_whitespace_pattern: Optional[str]


try:
@@ -314,6 +323,7 @@ def _sanitize_model_config(
        model_config.setdefault("max_num_seqs", 256)
        model_config.setdefault("quantization", None)
        model_config.setdefault("max_model_len", None)
        model_config.setdefault("guided_decoding_backend", "outlines")

        return model_config

@@ -325,6 +335,22 @@ def _sanitize_generate_config(
            generate_config = {}

        sanitized = VLLMGenerateConfig()

        response_format = generate_config.pop("response_format", None)
        guided_decoding_backend = generate_config.get("guided_decoding_backend", None)
        guided_json_object = None
        guided_json = None

        if response_format is not None:
            if response_format.get("type") == "json_object":
                guided_json_object = True
            elif response_format.get("type") == "json_schema":
                json_schema = response_format.get("json_schema")
                assert json_schema is not None
                guided_json = json_schema.get("json_schema")
                if guided_decoding_backend is None:
                    guided_decoding_backend = "outlines"

        sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
        sanitized.setdefault("n", generate_config.get("n", 1))
        sanitized.setdefault("best_of", generate_config.get("best_of", None))
@@ -346,6 +372,28 @@ def _sanitize_generate_config(
        sanitized.setdefault(
            "stream_options", generate_config.get("stream_options", None)
        )
        sanitized.setdefault(
            "guided_json", generate_config.get("guided_json", guided_json)
        )
        sanitized.setdefault("guided_regex", generate_config.get("guided_regex", None))
        sanitized.setdefault(
            "guided_choice", generate_config.get("guided_choice", None)
        )
        sanitized.setdefault(
            "guided_grammar", generate_config.get("guided_grammar", None)
        )
        sanitized.setdefault(
            "guided_whitespace_pattern",
            generate_config.get("guided_whitespace_pattern", None),
        )
        sanitized.setdefault(
            "guided_json_object",
            generate_config.get("guided_json_object", guided_json_object),
        )
        sanitized.setdefault(
            "guided_decoding_backend",
            generate_config.get("guided_decoding_backend", guided_decoding_backend),
        )

        return sanitized
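To illustrate the mapping above: an OpenAI-style `response_format` is translated into vLLM guided-decoding keys, and the schema payload for `json_schema` is read from the nested `json_schema` key. The sketch below assumes `_sanitize_generate_config` can be called as a class-level helper on `VLLMModel` and that the import path follows this file's location; the inputs are illustrative.

```python
from xinference.model.llm.vllm.core import VLLMModel  # assumed import path

# response_format={"type": "json_object"} becomes guided_json_object=True.
sanitized = VLLMModel._sanitize_generate_config(
    {"response_format": {"type": "json_object"}}
)
assert sanitized["guided_json_object"] is True

# response_format={"type": "json_schema", ...}: the schema payload is taken
# from the nested "json_schema" key and forwarded as guided_json, and the
# backend falls back to "outlines" when none was given.
sanitized = VLLMModel._sanitize_generate_config(
    {
        "response_format": {
            "type": "json_schema",
            "json_schema": {"json_schema": {"type": "object"}},
        }
    }
)
assert sanitized["guided_json"] == {"type": "object"}
assert sanitized["guided_decoding_backend"] == "outlines"
```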

@@ -483,13 +531,46 @@ async def async_generate(
            if isinstance(stream_options, dict)
            else False
        )
        sampling_params = SamplingParams(**sanitized_generate_config)

        if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
            # guided decoding only available for vllm >= 0.6.3
            from vllm.sampling_params import GuidedDecodingParams

            guided_options = GuidedDecodingParams.from_optional(
                json=sanitized_generate_config.pop("guided_json", None),
                regex=sanitized_generate_config.pop("guided_regex", None),
                choice=sanitized_generate_config.pop("guided_choice", None),
                grammar=sanitized_generate_config.pop("guided_grammar", None),
                json_object=sanitized_generate_config.pop("guided_json_object", None),
                backend=sanitized_generate_config.pop("guided_decoding_backend", None),
                whitespace_pattern=sanitized_generate_config.pop(
                    "guided_whitespace_pattern", None
                ),
            )

            sampling_params = SamplingParams(
                guided_decoding=guided_options, **sanitized_generate_config
            )
        else:
            # guided decoding is not supported on vllm < 0.6.3; drop the guided configs
            sanitized_generate_config.pop("guided_json", None)
            sanitized_generate_config.pop("guided_regex", None)
            sanitized_generate_config.pop("guided_choice", None)
            sanitized_generate_config.pop("guided_grammar", None)
            sanitized_generate_config.pop("guided_json_object", None)
            sanitized_generate_config.pop("guided_decoding_backend", None)
            sanitized_generate_config.pop("guided_whitespace_pattern", None)
            sampling_params = SamplingParams(**sanitized_generate_config)

        if not request_id:
            request_id = str(uuid.uuid1())

        assert self._engine is not None
        results_generator = self._engine.generate(
            prompt, sampling_params, request_id, lora_request=lora_request
            prompt,
            sampling_params,
            request_id,
            lora_request,
        )

        async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
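End to end, the feature can be exercised through Xinference's OpenAI-compatible endpoint. A hedged sketch follows: the endpoint URL and model uid are placeholders, `extra_body` is how the `guided_*` keys end up in the raw request body read by `extract_guided_params`, and `response_format` now passes validation in `CreateChatCompletionOpenAI`.

```python
import openai

# Placeholder endpoint and model uid; adjust to your Xinference deployment.
client = openai.OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-used")

# guided_choice constrains generation to one of the listed strings.
resp = client.chat.completions.create(
    model="my-vllm-model",
    messages=[{"role": "user", "content": "Answer with one word: red, green or blue."}],
    extra_body={"guided_choice": ["red", "green", "blue"]},
)
print(resp.choices[0].message.content)

# response_format is now accepted as well and is mapped to guided decoding
# inside the vLLM backend.
resp = client.chat.completions.create(
    model="my-vllm-model",
    messages=[{"role": "user", "content": "Return a JSON object with a single 'color' key."}],
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)
```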