Commit 4274507

FEAT: Update Qwen2-VL-Model to support flash_attention_2 implementation (#2289)

Co-authored-by: qinxuye <[email protected]>
LaureatePoet and qinxuye authored Sep 13, 2024
1 parent 56de933 commit 4274507
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions xinference/model/llm/transformers/qwen2_vl.py
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib.util
 import logging
+import sys
 import uuid
 from typing import Iterator, List, Optional, Union

@@ -59,9 +61,19 @@ def load(self):
             self.model_path, trust_remote_code=True
         )
         self._tokenizer = self._processor.tokenizer
-        self._model = Qwen2VLForConditionalGeneration.from_pretrained(
-            self.model_path, device_map=device, trust_remote_code=True
-        ).eval()
+        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
+        if flash_attn_installed:
+            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+                self.model_path,
+                torch_dtype="bfloat16",
+                device_map=device,
+                attn_implementation="flash_attention_2",
+                trust_remote_code=True,
+            ).eval()
+        else:
+            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+                self.model_path, device_map=device, trust_remote_code=True
+            ).eval()
 
     def _transform_messages(
         self,
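The load-time change probes for the optional `flash_attn` package with `importlib.util.find_spec`, which checks importability without actually importing the module. A minimal sketch of the same pattern, with `pick_model_kwargs` as a hypothetical helper name (not part of xinference):

```python
import importlib.util

def pick_model_kwargs(device: str = "auto") -> dict:
    """Assemble from_pretrained kwargs, opting into flash-attn only if installed."""
    kwargs = {"device_map": device, "trust_remote_code": True}
    if importlib.util.find_spec("flash_attn") is not None:
        # flash-attention kernels only support half-precision weights,
        # which is why torch_dtype is pinned alongside the implementation.
        kwargs.update(
            torch_dtype="bfloat16",
            attn_implementation="flash_attention_2",
        )
    return kwargs
```

`find_spec` returns `None` when the module cannot be located, so the fallback branch keeps the original default-attention load untouched.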
@@ -177,8 +189,18 @@ def _generate_stream(
             "streamer": streamer,
             **inputs,
         }
-
-        thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
+        error = None
+
+        def model_generate():
+            try:
+                return self._model.generate(**gen_kwargs)
+            except Exception:
+                nonlocal error
+                error = sys.exc_info()
+                streamer.end()
+                raise
+
+        thread = Thread(target=model_generate)
         thread.start()
 
         completion_id = str(uuid.uuid1())
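The streaming change wraps `self._model.generate` so an exception in the worker thread no longer dies silently: the wrapper stores `sys.exc_info()`, calls `streamer.end()` so the consumer loop stops waiting for tokens, and the consumer re-raises after the loop. A self-contained sketch of that hand-off pattern, with `stream_tokens` and `produce` as illustrative names only:

```python
import sys
from queue import Queue
from threading import Thread

def stream_tokens(produce):
    """Yield items from a worker thread; re-raise its exception in the consumer."""
    q: Queue = Queue()
    error = None

    def worker():
        nonlocal error
        try:
            for token in produce():
                q.put(token)
        except Exception:
            error = sys.exc_info()  # keep type, value, and traceback
        finally:
            q.put(None)  # sentinel: unblock the consumer either way

    Thread(target=worker).start()
    while (token := q.get()) is not None:
        yield token
    if error:
        _, err, tb = error
        raise err.with_traceback(tb)  # original traceback survives the thread hop
```

Checking `error` only after the loop mirrors the diff below: the consumer drains until the worker signals the end of the stream, and only then re-raises with the original traceback intact.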
@@ -195,6 +217,10 @@ def _generate_stream(
                 has_content=True,
             )
 
+        if error:
+            _, err, tb = error  # type: ignore
+            raise err.with_traceback(tb)
+
         yield generate_completion_chunk(
             chunk_text=None,
             finish_reason="stop",
