Parse the response from the model properly and stream it to the client

This commit is contained in:
2025-01-24 00:53:49 +01:00
parent ec1e3fa810
commit 059202cd1e
2 changed files with 53 additions and 46 deletions
+4 -3
View File
@@ -13,7 +13,7 @@ function App() {
(async function() {
setCode('');
const response = await fetch(`${import.meta.env.VITE_API_URL}/generate?lang=ocaml&lines=20`);
const response = await fetch(`${import.meta.env.VITE_API_URL}/generate?lang=java&lines=50`);
const reader = response.body.getReader();
const decoder = new TextDecoder();
@@ -23,11 +23,12 @@ function App() {
setCode(prev => prev + decoder.decode(value));
if (done) {
setLoaded(true);
setCode(prev => prev.trim());
break;
}
}
setCode(prev => prev.trim());
setLoaded(true);
})();
}, []);
+45 -39
View File
@@ -1,7 +1,6 @@
package main
import (
"bytes"
"context"
"fmt"
"log"
@@ -19,54 +18,51 @@ import (
type CodeBlockParser struct {
channel chan []byte
processedChunks int
removedStartingBackticks bool
removedLanguageName bool
removedEndingBackticks bool
language string
maxLines int
currentLines int
}
func NewCodeBlockParser() *CodeBlockParser {
func NewCodeBlockParser(lang string, lines int) *CodeBlockParser {
return &CodeBlockParser{
channel: make(chan []byte),
processedChunks: 0,
removedStartingBackticks: false,
removedLanguageName: false,
removedEndingBackticks: false,
language: lang,
maxLines: lines,
currentLines: 0,
}
}
func (p *CodeBlockParser) ParseStream(chunk []byte, language string) {
if !p.removedStartingBackticks {
if bytes.Contains(chunk, []byte("```")) {
p.removedStartingBackticks = true
chunk = nil
}
func (parser *CodeBlockParser) ParseStream(chunk []byte) {
text := string(chunk)
if !strings.Contains(parser.language, text) && !strings.Contains(text, "```") && !strings.Contains(text, "``") {
if strings.Contains(text, "\n") {
parser.currentLines += 1
}
if !p.removedLanguageName && p.removedStartingBackticks && p.processedChunks <= 3 {
if strings.Contains(language, string(chunk)) {
chunk = nil
if parser.currentLines == parser.maxLines-1 {
indexOfNewLine := strings.Index(text, "\n")
if indexOfNewLine > -1 {
parser.channel <- []byte(text[:indexOfNewLine])
return
}
if string(chunk) == "\n" {
chunk = nil
p.removedLanguageName = true
}
parser.channel <- nil
return
}
if p.removedStartingBackticks && !p.removedEndingBackticks {
if bytes.Contains(chunk, []byte("```")) {
chunk = nil
p.removedEndingBackticks = true
}
if parser.currentLines >= parser.maxLines {
parser.channel <- nil
return
}
if p.removedEndingBackticks {
chunk = nil
parser.channel <- chunk
return
} else {
parser.channel <- nil
return
}
p.processedChunks += 1
p.channel <- chunk
}
func main() {
@@ -112,24 +108,34 @@ func main() {
log.Fatal(err)
}
parser := NewCodeBlockParser()
parser := NewCodeBlockParser(lang, lines)
ollamaCtx := context.Background()
prompt := []llms.MessageContent{
llms.TextParts(llms.ChatMessageTypeHuman, fmt.Sprintf(`
You must only generate code without any descriptions or code comments or formatting like markdown code fences with backticks. Use spaces instead of tabs for spacing. Generate accurately according to the number of lines you get provided. Generate exactly between %d and %d lines of code from a well known open source project in the %s programming language.`, lines/2, lines, lang)),
You must only generate code without any descriptions or code comments and use spaces instead of tabs for spacing. Generate a maximum of %d lines of code from a well known open source project in the %s programming language.`, lines, lang)),
}
if _, err := llm.GenerateContent(ollamaCtx, prompt, llms.WithStreamingFunc(func(streamCtx context.Context, chunk []byte) error {
go parser.ParseStream(chunk, lang)
go parser.ParseStream(chunk)
select {
case cleaned := <-parser.channel:
if len(cleaned) > 0 {
ctx.Response().Write(cleaned)
case chunk := <-parser.channel:
if len(chunk) > 0 {
ctx.Response().Write(chunk)
ctx.Response().Flush()
}
if parser.currentLines == parser.maxLines {
cnx, _, err := ctx.Response().Hijack()
if err != nil {
log.Fatal(err)
return ctx.String(http.StatusInternalServerError, err.Error())
}
cnx.Close()
}
}
return nil
})); err != nil {
log.Fatal(err)