diff --git a/client/src/App.jsx b/client/src/App.jsx index 8a59b55..c4c0481 100644 --- a/client/src/App.jsx +++ b/client/src/App.jsx @@ -14,7 +14,7 @@ function App() { (async function() { setCode(''); - const response = await fetch(`${import.meta.env.VITE_API_URL}/generate?lang=java&lines=15`); + const response = await fetch(`${import.meta.env.VITE_API_URL}/generate?lang=lua`); const reader = response.body.getReader(); const decoder = new TextDecoder(); diff --git a/server/cmd/main.go b/server/cmd/main.go new file mode 100644 index 0000000..4f13fde --- /dev/null +++ b/server/cmd/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "log" + "os" + + "github.com/joho/godotenv" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "touch-programming.hazemkrimi.tech/internal/handlers" +) + +func main() { + err := godotenv.Load() + + if err != nil { + log.Fatal("Error loading environment!") + } + + PORT := os.Getenv("PORT") + + if len(PORT) == 0 { + PORT = "8080" + } + + ech := echo.New() + + ech.Use(middleware.CORS()) + ech.GET("/generate", handlers.Generate) + ech.Logger.Fatal(ech.Start(fmt.Sprintf(":%s", PORT))) +} diff --git a/server/go.mod b/server/go.mod index 52477bd..3261f7c 100644 --- a/server/go.mod +++ b/server/go.mod @@ -1,4 +1,4 @@ -module experiment/oss-code-generator-ai-powered +module touch-programming.hazemkrimi.tech go 1.23.4 diff --git a/server/internal/handlers/generate.go b/server/internal/handlers/generate.go new file mode 100644 index 0000000..81c8c78 --- /dev/null +++ b/server/internal/handlers/generate.go @@ -0,0 +1,142 @@ +package handlers + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "strconv" + "strings" + "sync" + + "github.com/labstack/echo/v4" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/ollama" +) + +type ErrorMaxLinesReached struct{} + +func (err *ErrorMaxLinesReached) Error() string { + return fmt.Sprintf("Reached maximum lines configured!") +} + +type CodeBlockParser struct { + wg sync.WaitGroup + channel chan []byte + language string + maxLines int + currentLines int + backticksRemoved int +} + +func NewCodeBlockParser(lang string, lines int) *CodeBlockParser { + return &CodeBlockParser{ + channel: make(chan []byte), + language: lang, + maxLines: lines, + currentLines: 0, + backticksRemoved: 0, + } +} + +func (parser *CodeBlockParser) ParseStream(chunk []byte, cancelCtx context.CancelFunc) { + defer parser.wg.Done() + + text := string(chunk) + + // TODO: After getting the project up and running optimize this to take into account for example open brackets or function ending statements that will be closed after the max lines gets reached. + if !strings.Contains(parser.language, text) && !strings.Contains(text, "```") && !strings.Contains(text, "``") { + if parser.currentLines == 0 && parser.backticksRemoved > 0 && text == "\n" { + parser.channel <- nil + return + } + + if strings.Contains(text, "\n") { + parser.currentLines++ + } + + if parser.currentLines == parser.maxLines-1 { + indexOfNewLine := strings.Index(text, "\n") + + if indexOfNewLine > -1 { + parser.channel <- []byte(text[:indexOfNewLine]) + } else { + parser.channel <- nil + } + } else if parser.currentLines >= parser.maxLines { + cancelCtx() + } else { + parser.channel <- chunk + } + } else { + parser.channel <- nil + parser.backticksRemoved++ + } +} + +func Generate(ctx echo.Context) error { + LLM_MODEL := os.Getenv("LLM_MODEL") + + if len(LLM_MODEL) == 0 { + return ctx.String(http.StatusInternalServerError, "No LLM model specified in environment!") + } + + MAX_LINES, err := strconv.Atoi(os.Getenv("MAX_LINES")) + + if err != nil { + return ctx.String(http.StatusInternalServerError, "Error setting max lines!") + } + + lang := ctx.QueryParam("lang") + + if len(lang) == 0 { + return ctx.String(http.StatusBadRequest, "Lang param is incorrect!") + } + + llm, err := ollama.New(ollama.WithModel(LLM_MODEL)) + + if err != nil { + log.Println(err) + return ctx.String(http.StatusInternalServerError, "Error initializing LLM!") + } + + parser := NewCodeBlockParser(lang, MAX_LINES) + ollamaCtx, cancelOllamaCtx := context.WithCancel(context.Background()) + prompt := []llms.MessageContent{ + llms.TextParts(llms.ChatMessageTypeHuman, fmt.Sprintf(` + You must only generate code without any text descriptions or code comments and use spaces instead of tabs for spacing. Generate a maximum of %d lines of code from a well known open source project in the %s programming language.`, MAX_LINES, lang)), + } + + ctx.Response().WriteHeader(http.StatusOK) + + _, err = llm.GenerateContent(ollamaCtx, prompt, llms.WithStreamingFunc(func(streamCtx context.Context, chunk []byte) error { + parser.wg.Add(1) + go parser.ParseStream(chunk, cancelOllamaCtx) + + select { + case chunk, ok := <-parser.channel: + if ok && len(chunk) > 0 { + ctx.Response().Write(chunk) + ctx.Response().Flush() + } + case <-ollamaCtx.Done(): + return &ErrorMaxLinesReached{} + } + + return nil + })) + + parser.wg.Wait() + defer close(parser.channel) + + if err != nil { + if _, ok := err.(*ErrorMaxLinesReached); ok { + return nil + } + + return ctx.String(http.StatusInternalServerError, "Error generating code!") + } + + return nil +} diff --git a/server/main.go b/server/main.go deleted file mode 100644 index 047d18d..0000000 --- a/server/main.go +++ /dev/null @@ -1,151 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log" - "net/http" - "os" - "strconv" - "strings" - - "github.com/joho/godotenv" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/ollama" -) - -type CodeBlockParser struct { - channel chan []byte - language string - maxLines int - currentLines int -} - -func NewCodeBlockParser(lang string, lines int) *CodeBlockParser { - return &CodeBlockParser{ - channel: make(chan []byte), - language: lang, - maxLines: lines, - currentLines: 0, - } -} - -func (parser *CodeBlockParser) ParseStream(chunk []byte) { - text := string(chunk) - - if !strings.Contains(parser.language, text) && !strings.Contains(text, "```") && !strings.Contains(text, "``") { - if strings.Contains(text, "\n") { - parser.currentLines += 1 - } - - if parser.currentLines == parser.maxLines-1 { - indexOfNewLine := strings.Index(text, "\n") - - if indexOfNewLine > -1 { - parser.channel <- []byte(text[:indexOfNewLine]) - return - } - - parser.channel <- nil - return - } - - if parser.currentLines >= parser.maxLines { - parser.channel <- nil - return - } - - parser.channel <- chunk - return - } else { - parser.channel <- nil - return - } -} - -func main() { - err := godotenv.Load() - - if err != nil { - log.Fatal("Error loading environment!") - } - - LLM_MODEL := os.Getenv("LLM_MODEL") - PORT := os.Getenv("PORT") - - if len(LLM_MODEL) == 0 { - log.Fatal("No LLM model specified in environment!") - } - - if len(PORT) == 0 { - PORT = "8080" - } - - ech := echo.New() - - ech.Use(middleware.CORS()) - ech.GET("/generate", func(ctx echo.Context) error { - // TODO: Make lines an environment variable and tweak it along with the prompt to get a suitable number of lines for the challenge to be fun and not too hard but still challenging. - lines, err := strconv.Atoi(ctx.QueryParam("lines")) - - if err != nil { - return ctx.String(http.StatusBadRequest, "Lines param is not provided or incorrect!") - } - - lang := ctx.QueryParam("lang") - - if lang == "" { - return ctx.String(http.StatusBadRequest, "Lang param is not provided or incorrect!") - } - - ctx.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - ctx.Response().WriteHeader(http.StatusOK) - - llm, err := ollama.New(ollama.WithModel(fmt.Sprintf("%s", LLM_MODEL))) - - if err != nil { - log.Fatal(err) - } - - parser := NewCodeBlockParser(lang, lines) - ollamaCtx := context.Background() - prompt := []llms.MessageContent{ - llms.TextParts(llms.ChatMessageTypeHuman, fmt.Sprintf(` - You must only generate code without any text descriptions or code comments and use spaces instead of tabs for spacing. Generate a maximum of %d lines of code from a well known open source project in the %s programming language.`, lines, lang)), - } - - if _, err := llm.GenerateContent(ollamaCtx, prompt, llms.WithStreamingFunc(func(streamCtx context.Context, chunk []byte) error { - go parser.ParseStream(chunk) - - select { - case chunk := <-parser.channel: - if len(chunk) > 0 { - ctx.Response().Write(chunk) - ctx.Response().Flush() - } - - if parser.currentLines == parser.maxLines { - cnx, _, err := ctx.Response().Hijack() - - if err != nil { - log.Fatal(err) - return ctx.String(http.StatusInternalServerError, err.Error()) - } - - cnx.Close() - } - } - return nil - })); err != nil { - log.Fatal(err) - return ctx.String(http.StatusInternalServerError, err.Error()) - } - - defer close(parser.channel) - return nil - }) - - ech.Logger.Fatal(ech.Start(fmt.Sprintf(":%s", PORT))) -}