diff --git a/cmd/web/handlers.go b/cmd/web/handlers.go index 7debd6c..9d0c233 100644 --- a/cmd/web/handlers.go +++ b/cmd/web/handlers.go @@ -9,6 +9,7 @@ import ( "fmt" "math/rand" "net/http" + "slices" "time" ) @@ -140,22 +141,95 @@ func (app *application) createVotes(w http.ResponseWriter, r *http.Request) { app.serverError(w, r, err) } - // TODO check if he has voted - - //var voterIdentity string - if election.AreVotersKnown { - if request.VoterIdentity == nil || validator.Blank(*request.VoterIdentity) { - app.unprocessableEntityErrorSingle(w, fmt.Errorf("election has known voters; you must provide an identity provided by the organizer")) + for _, c := range request.Choices { + choiceExists := slices.Contains(election.Choices, c.ChoiceText) + if !choiceExists { + app.unprocessableEntityErrorSingle(w, fmt.Errorf("choice %v doesn't exist", c.ChoiceText)) return } - //voterIdentity = *request.VoterIdentity - } else { - // TODO: get requester's IP address as identity } - // TODO verify if choice exists - // TODO count tokens to make sure user isn't trying to cheat + tokensUsed := 0 + for _, c := range request.Choices { + tokensUsed += c.Tokens + } - json, _ := json.Marshal(election) - w.Write(json) + if tokensUsed > election.Tokens { + app.unprocessableEntityErrorSingle(w, fmt.Errorf("you used too many tokens; must not exceed %v tokens", election.Tokens)) + return + } + + electionHasExpired := election.ExpiresAt.Before(time.Now()) + if electionHasExpired { + app.unprocessableEntityErrorSingle(w, fmt.Errorf("election has expired")) + return + } + + // this snippet of code also inserts in the `voters` table + voterIdentity := func() string { + var voterIdentity string + if election.AreVotersKnown { + if request.VoterIdentity == nil || validator.Blank(*request.VoterIdentity) { + app.unprocessableEntityErrorSingle(w, fmt.Errorf("election has known voters; you must provide an identity provided by the organizer")) + return "" + } + + voterIdentity = *request.VoterIdentity + hasCastVotes, err := app.votes.Exists(voterIdentity, election.ID) + if err != nil { + app.serverError(w, r, err) + return "" + } + if hasCastVotes { + app.unprocessableEntityErrorSingle(w, fmt.Errorf("you already voted")) + return "" + } + } else { + voterIdentity = r.RemoteAddr + + // if voters are known, voter will always exist + voterExists, err := app.voters.Exists(voterIdentity, election.ID) + if err != nil { + app.serverError(w, r, err) + return "" + } + if voterExists { + app.unprocessableEntityErrorSingle(w, fmt.Errorf("you already voted")) + return "" + } + + _, err = app.voters.Insert(voterIdentity, election.ID) + if err != nil { + app.serverError(w, r, err) + return "" + } + } + return voterIdentity + }() + + if voterIdentity == "" { + return + } + + if !election.AreVotersKnown { + voterCount, err := app.voters.CountByElection(election.ID) + if err != nil && !errors.Is(sql.ErrNoRows, err) { + app.serverError(w, r, err) + return + } + // if voters are known, voterCount == election.MaxVoters in all cases + if voterCount == election.MaxVoters { + app.unprocessableEntityErrorSingle(w, fmt.Errorf("maximum voters reached")) + return + } + } + + for _, c := range request.Choices { + _, err := app.votes.Insert(voterIdentity, election.ID, c.ChoiceText, c.Tokens) + if err != nil { + app.serverError(w, r, err) + } + } + + w.WriteHeader(http.StatusCreated) } diff --git a/cmd/web/main.go b/cmd/web/main.go index 4491e6b..4ef94cf 100644 --- a/cmd/web/main.go +++ b/cmd/web/main.go @@ -39,6 +39,7 @@ func main() { logger: logger, elections: &models.ElectionModel{DB: db}, voters: &models.VoterModel{DB: db}, + votes: &models.VoteModel{DB: db}, } logger.Info("Starting server", "addr", addr) diff --git a/cmd/web/routes.go b/cmd/web/routes.go index 732e022..3dd6def 100644 --- a/cmd/web/routes.go +++ b/cmd/web/routes.go @@ -11,6 +11,7 @@ type application struct { logger *slog.Logger elections models.ElectionModelInterface voters models.VoterModelInterface + votes models.VoteModelInterface } func (app *application) routes() http.Handler { diff --git a/cmd/web/testutils_test.go b/cmd/web/testutils_test.go index 7e9e228..b5aebb4 100644 --- a/cmd/web/testutils_test.go +++ b/cmd/web/testutils_test.go @@ -15,6 +15,7 @@ func newTestApplication(t *testing.T) *application { logger: slog.New(slog.NewTextHandler(io.Discard, nil)), elections: &mockElectionModel{}, voters: &mockVoterModel{}, + votes: &mockVoteModel{}, } } @@ -56,8 +57,8 @@ func (e *mockElectionModel) GetById(id int) (*models.Election, error) { Tokens: 100, AreVotersKnown: false, MaxVoters: 10, - CreatedAt: time.Now().String(), - ExpiresAt: time.Now().Add(100 * time.Hour).String(), + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(100 * time.Hour), }, nil } @@ -67,3 +68,22 @@ type mockVoterModel struct { func (v *mockVoterModel) Insert(identity string, electionID int) (int, error) { return 1, nil } + +func (v *mockVoterModel) CountByElection(electionID int) (int, error) { + return 10, nil +} + +func (v *mockVoterModel) Exists(voterIdentity string, electionID int) (bool, error) { + return true, nil +} + +type mockVoteModel struct { +} + +func (v *mockVoteModel) Insert(voterIdentity string, electionId int, choiceText string, tokens int) (int, error) { + return 1, nil +} + +func (v *mockVoteModel) Exists(voterIdentity string, electionID int) (bool, error) { + return true, nil +} diff --git a/internal/models/elections.go b/internal/models/elections.go index a7591ed..81aaef1 100644 --- a/internal/models/elections.go +++ b/internal/models/elections.go @@ -20,8 +20,8 @@ type Election struct { Tokens int AreVotersKnown bool MaxVoters int - CreatedAt string - ExpiresAt string + CreatedAt time.Time + ExpiresAt time.Time Choices []string } diff --git a/internal/models/voters.go b/internal/models/voters.go index 4545ad8..6a760a0 100644 --- a/internal/models/voters.go +++ b/internal/models/voters.go @@ -2,10 +2,13 @@ package models import ( "database/sql" + "errors" ) type VoterModelInterface interface { Insert(identity string, electionID int) (int, error) + CountByElection(electionID int) (int, error) + Exists(voterIdentity string, electionID int) (bool, error) } type VoterModel struct { @@ -34,3 +37,53 @@ func (v *VoterModel) Insert(identity string, electionID int) (int, error) { voterId, err := result.LastInsertId() return int(voterId), nil } + +func (v *VoterModel) CountByElection(electionID int) (int, error) { + // use a transaction to prevent race conditions + tx, err := v.DB.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + + query := ` + SELECT COUNT(identity) + FROM voters + WHERE election_id = ? + GROUP BY election_id; + ` + + row := tx.QueryRow(query, electionID) + + var voterCount int + + err = row.Scan(&voterCount) + if err != nil { + return 0, err + } + + tx.Commit() + + return voterCount, nil +} + +func (v *VoterModel) Exists(voterIdentity string, electionID int) (bool, error) { + query := ` + SELECT EXISTS ( + SELECT 1 + FROM voters + WHERE identity = ? AND election_id = ? + ) + ` + + var exists bool + + err := v.DB.QueryRow(query, voterIdentity, electionID).Scan(&exists) + if err != nil { + if errors.Is(sql.ErrNoRows, err) { + return false, nil + } + return false, err + } + return exists, nil +} diff --git a/internal/models/votes.go b/internal/models/votes.go new file mode 100644 index 0000000..513d762 --- /dev/null +++ b/internal/models/votes.go @@ -0,0 +1,57 @@ +package models + +import ( + "database/sql" + "errors" +) + +type VoteModelInterface interface { + Insert(voterIdentity string, electionId int, choiceText string, tokens int) (int, error) + Exists(voterIdentity string, electionID int) (bool, error) +} + +type VoteModel struct { + DB *sql.DB +} + +func (v *VoteModel) Insert(voterIdentity string, electionId int, choiceText string, tokens int) (int, error) { + tx, err := v.DB.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + + result, err := tx.Exec(` + INSERT INTO votes (voter_identity, election_id, choice_text, tokens) + VALUES (?, ?, ?, ?)`, + voterIdentity, electionId, choiceText, tokens) + if err != nil { + return 0, err + } + + if err = tx.Commit(); err != nil { + return 0, err + } + + voteId, err := result.LastInsertId() + return int(voteId), nil +} + +func (v *VoteModel) Exists(voterIdentity string, electionID int) (bool, error) { + var exists bool + query := ` + SELECT EXISTS ( + SELECT 1 + FROM votes + WHERE voter_identity = ? AND election_id = ? + ) + ` + err := v.DB.QueryRow(query, voterIdentity, electionID).Scan(&exists) + if err != nil { + if errors.Is(sql.ErrNoRows, err) { + return false, nil + } + return false, err + } + return exists, nil +}