Skip to content

Commit

Permalink
Merge pull request #4781 from hzhaoy/fix-dockerfile-cuda
Browse files Browse the repository at this point in the history
Fix CUDA Dockerfile
  • Loading branch information
hiyouga authored Jul 13, 2024
2 parents 6b48308 + 8bab99c commit 5da54de
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
15 changes: 8 additions & 7 deletions docker/docker-cuda/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ FROM nvcr.io/nvidia/pytorch:24.02-py3
# Define environments
ENV MAX_JOBS=4
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn

# Define installation arguments
ARG INSTALL_BNB=false
Expand All @@ -23,13 +24,6 @@ RUN pip config set global.index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt

# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi

# Copy the rest of the application into the image
COPY . /app

Expand All @@ -46,6 +40,13 @@ RUN EXTRA_PACKAGES="metrics"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi

# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]

Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/train/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", contr
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
if isinstance(model, PeftModel):
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
setattr(model.peft_config["default"], "init_lora_weights", True)
Expand Down

0 comments on commit 5da54de

Please sign in to comment.