diff --git a/server/main.go b/server/main.go index ac563f8..c9df8a0 100644 --- a/server/main.go +++ b/server/main.go @@ -18,7 +18,9 @@ import ( ) type CodeBlockParser struct { + channel chan []byte processedChunks int + processedLines int removedStartingBackticks bool removedLanguageName bool removedEndingBackticks bool @@ -26,15 +28,20 @@ type CodeBlockParser struct { func NewCodeBlockParser() *CodeBlockParser { return &CodeBlockParser{ + channel: make(chan []byte), processedChunks: 0, + processedLines: 0, removedStartingBackticks: false, removedLanguageName: false, removedEndingBackticks: false, } } -// TODO: Use channels to stream date at a specific byte size to make control restrictions properly. -func (p *CodeBlockParser) ParseStream(chunk []byte, language string) []byte { +func (p *CodeBlockParser) ParseStream(chunk []byte, language string) { + if strings.Contains(string(chunk), "\n") { + p.processedLines += 1 + } + if !p.removedStartingBackticks { if bytes.Contains(chunk, []byte("```")) { p.removedStartingBackticks = true @@ -42,7 +49,7 @@ func (p *CodeBlockParser) ParseStream(chunk []byte, language string) []byte { } } - if p.removedStartingBackticks && p.processedChunks <= 3 { + if !p.removedLanguageName && p.removedStartingBackticks && p.processedChunks <= 3 { if strings.Contains(language, string(chunk)) { chunk = nil } @@ -61,7 +68,7 @@ func (p *CodeBlockParser) ParseStream(chunk []byte, language string) []byte { } p.processedChunks += 1 - return chunk + p.channel <- chunk } func main() { @@ -116,11 +123,14 @@ func main() { } if _, err := llm.GenerateContent(ollamaCtx, prompt, llms.WithStreamingFunc(func(streamCtx context.Context, chunk []byte) error { - cleaned := parser.ParseStream(chunk, lang) + go parser.ParseStream(chunk, lang) - if len(cleaned) > 0 { - ctx.Response().Write(cleaned) - ctx.Response().Flush() + select { + case cleaned := <-parser.channel: + if len(cleaned) > 0 { + ctx.Response().Write(cleaned) + ctx.Response().Flush() + } } return nil @@ -129,6 +139,7 @@ func main() { return ctx.String(http.StatusInternalServerError, err.Error()) } + defer close(parser.channel) return nil })