Clean LLM output WIP

This commit is contained in:
2025-01-17 22:14:03 +01:00
parent 5c94a354d8
commit 03656acb6f
+49 -5
View File
@@ -1,6 +1,7 @@
package main
import (
"bytes"
"context"
"fmt"
"log"
@@ -15,6 +16,43 @@ import (
"github.com/tmc/langchaingo/llms/ollama"
)
type CodeBlockParser struct {
processed int
foundStart bool
foundEnd bool
}
func NewCodeBlockParser() *CodeBlockParser {
return &CodeBlockParser{
processed: 0,
foundStart: false,
foundEnd: false,
}
}
func (p *CodeBlockParser) ParseStream(chunk []byte) []byte {
if !p.foundStart {
if bytes.Contains(chunk, []byte("```")) {
p.foundStart = true
chunk = nil
}
}
if p.foundStart && p.processed == 1 {
chunk = nil
}
if p.foundStart && !p.foundEnd {
if bytes.Contains(chunk, []byte("```")) {
p.foundEnd = true
chunk = nil
}
}
p.processed += 1
return chunk
}
func main() {
err := godotenv.Load()
@@ -58,17 +96,23 @@ func main() {
log.Fatal(err)
}
// TODO: Clean the prompt result of any unnecessary formatting or text
parser := NewCodeBlockParser()
ollamaCtx := context.Background()
content := []llms.MessageContent{
prompt := []llms.MessageContent{
llms.TextParts(llms.ChatMessageTypeSystem, `You must only generate code without any descriptions. Also don't include any code comments and use spaces instead of tabs for spacing. Most importantly. Most importantly you must remove any markdown code fences that wrap the content!`),
llms.TextParts(llms.ChatMessageTypeHuman, fmt.Sprintf(`
Generate a maximum of %d lines of code from a well known open source project in the %s programming language.`, lines, lang)),
}
if _, err := llm.GenerateContent(ollamaCtx, content, llms.WithStreamingFunc(func(streamCtx context.Context, chunk []byte) error {
ctx.Response().Write(chunk)
ctx.Response().Flush()
if _, err := llm.GenerateContent(ollamaCtx, prompt, llms.WithStreamingFunc(func(streamCtx context.Context, chunk []byte) error {
cleaned := parser.ParseStream(chunk)
if len(cleaned) > 0 {
fmt.Println(chunk, string(chunk))
ctx.Response().Write(cleaned)
ctx.Response().Flush()
}
return nil
})); err != nil {
log.Fatal(err)