Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(progressbar): hopefully fix the progress bar updating #148

Merged
merged 1 commit into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The application allows users to interactively select models, sort, filter, edit,
- [go install (recommended)](#go-install-recommended)
- [curl](#curl)
- [Manually](#manually)
- [if "command not found: gollama"](#if-command-not-found-gollama)
- [Usage](#usage)
- [Key Bindings](#key-bindings)
- [Top](#top)
Expand Down Expand Up @@ -110,11 +111,11 @@ echo "alias g=gollama" >> ~/.zshrc
- `i`: Inspect model
- `t`: Top (show running models)
- `D`: Delete model
- `e`: Edit model **new**
- `e`: Edit model
- `c`: Copy model
- `U`: Unload all models
- `p`: Pull an existing model **new**
- `g`: Pull (get) new model **new**
- `p`: Pull an existing model
- `ctrl+p`: Pull (get) new model
- `P`: Push model
- `n`: Sort by name
- `s`: Sort by size
Expand Down Expand Up @@ -159,7 +160,7 @@ Note: Requires Admin privileges if you're running Windows.
- `-u`: Unload all running models
- `-v`: Print the version and exit
- `-h`, or `--host`: Specify the host for the Ollama API
- `-H`: Shortcut for `-h http://localhost:11434` (connect to local Ollama API) **new**
- `-H`: Shortcut for `-h http://localhost:11434` (connect to local Ollama API)
- `--vram`: Estimate vRAM usage for a model. Accepts:
- Ollama models (e.g. `llama3.1:8b-instruct-q6_K`, `qwen2:14b-q4_0`)
- HuggingFace models (e.g. `NousResearch/Hermes-2-Theta-Llama-3-8B`)
Expand Down
14 changes: 11 additions & 3 deletions app_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,17 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m.handlePullErrorMsg(msg)
case progressMsg:
if m.pullProgress < 1.0 {
m.pullProgress = msg.progress
return m, m.updateProgressCmd()
return m, tea.Batch(
m.updateProgressCmd(),
func() tea.Msg {
return progressMsg{
modelName: msg.modelName,
progress: m.pullProgress,
}
},
)
}
return m, nil
}
}
switch msg := msg.(type) {
Expand Down Expand Up @@ -412,7 +420,7 @@ func (m *AppModel) handlePullErrorMsg(msg pullErrorMsg) (tea.Model, tea.Cmd) {
}

func (m *AppModel) updateProgressCmd() tea.Cmd {
return tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg {
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
return progressMsg{
modelName: m.pullInput.Value(),
progress: m.pullProgress,
Expand Down
2 changes: 1 addition & 1 deletion keymap.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewKeyMap() *KeyMap {
LinkModel: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "link (L=all)")),
PushModel: key.NewBinding(key.WithKeys("P"), key.WithHelp("P", "push")),
PullModel: key.NewBinding(key.WithKeys("p"), key.WithHelp("p", "pull")),
PullNewModel: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "get")),
PullNewModel: key.NewBinding(key.WithKeys("ctrl+p"), key.WithHelp("ctrl+p", "pull new model")),
Quit: key.NewBinding(key.WithKeys("q")),
RunModel: key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "run")),
SortByFamily: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "^family")),
Expand Down
58 changes: 45 additions & 13 deletions operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,54 @@ func (m *AppModel) startPullModel(modelName string) tea.Cmd {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

req := &api.PullRequest{Name: modelName}
err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
if !m.pulling {
return context.Canceled
progressChan := make(chan float64)
errChan := make(chan error)

go func() {
req := &api.PullRequest{Name: modelName}
err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
if !m.pulling {
return context.Canceled
}
progress := float64(resp.Completed) / float64(resp.Total)
m.pullProgress = progress
progressChan <- progress
return nil
})

if err == context.Canceled {
errChan <- fmt.Errorf("pull cancelled")
return
}
m.pullProgress = float64(resp.Completed) / float64(resp.Total)
return nil
})
if err != nil {
errChan <- err
return
}
close(progressChan)
}()

if err == context.Canceled {
return pullErrorMsg{fmt.Errorf("pull cancelled")}
}
if err != nil {
return pullErrorMsg{err}
// Start a ticker to send progress updates
ticker := time.NewTicker(time.Second)
defer ticker.Stop()

for {
select {
case err := <-errChan:
if err != nil {
return pullErrorMsg{err}
}
return pullSuccessMsg{modelName}
case <-ticker.C:
return progressMsg{
modelName: modelName,
progress: m.pullProgress,
}
case progress := <-progressChan:
if progress >= 1.0 {
return pullSuccessMsg{modelName}
}
}
}
return pullSuccessMsg{modelName}
}
}

Expand Down
Loading