1package pg_store23// TODO: Add test for WriteSitePosts.4// TODO: Add test for TopPostsForSite.56import (7 "context"8 "flag"9 "fmt"10 "math/rand/v2"11 "net/netip"12 "net/url"13 "os"14 "testing"15 "time"1617 "website-feeds/model"18 "website-feeds/model/modeltest"1920 "github.com/jackc/pgx/v5/pgxpool"21)2223var pgaddr = flag.String(24 "database",25 os.Getenv("WEBSITE_FEEDS_TEST_DATABASE_URL"),26 "database address",27)2829func TestBasic(t *testing.T) {30 WithDatabase(t.Context(), t, func(t *testing.T, pool *pgxpool.Pool) {31 s, err := newFromPool(t.Context(), pool)32 if err != nil {33 t.Fatalf("unexpected error while creating db: %e", err)34 }35 err = s.updateSchema(t.Context())36 if err != nil {37 t.Fatalf("failed no-op schema update: %s", err)38 }39 s.Close()40 })41}4243func TestCreateSite(t *testing.T) {44 WithPgStore(t.Context(), t, func(t *testing.T, s *pgStore) {45 siteName := "TEST_SITE"46 siteDisplayName := "TestSiteButPretty"4748 site, err := s.CreateSite(t.Context(), siteName, siteDisplayName)49 if err != nil {50 t.Fatalf("unexpected error: %s", err)51 }5253 if site.DisplayName != siteDisplayName {54 t.Errorf(55 "wrong returned site display name: got=%s want=%s",56 site.DisplayName,57 siteDisplayName,58 )59 }6061 if site.Name != siteName {62 t.Errorf(63 "wrong returned site name: got=%s want=%s",64 site.Name,65 siteName,66 )67 }6869 if site.LastFetched != nil {70 t.Errorf("last_fetched should be nil. got:%v", site.LastFetched)71 }72 })73}7475func TestSettings(t *testing.T) {76 WithPgStore(t.Context(), t, func(t *testing.T, s *pgStore) {77 settings := &model.Settings{78 Reddit: &model.RedditSettings{79 ClientID: "REDDIT_CLIENT_ID",80 ClientSecret: "REDDIT_CLIENT_SECRET",81 },82 DefaultNumPosts: 111,83 MaxNumPosts: 222,84 SiteFetchWait: 333 * time.Second,85 }8687 err := s.UpdateSettings(t.Context(), settings)88 if err != nil {89 t.Fatalf("couldn't set settings: %s", err)90 }9192 dbSettings, err := s.Settings(t.Context())93 if err != nil {94 t.Fatalf("couldn't get first settings: %s", err)95 }9697 modeltest.CompareSettings(t, dbSettings, settings)9899 newSettings := &model.Settings{100 Reddit: nil,101 DefaultNumPosts: 1111,102 MaxNumPosts: 2222,103 SiteFetchWait: 3333 * time.Second,104 }105106 err = s.UpdateSettings(t.Context(), newSettings)107 if err != nil {108 t.Fatalf("couldn't set settings: %s", err)109 }110111 newDBSettings, err := s.Settings(t.Context())112 if err != nil {113 t.Fatalf("couldn't get first settings: %s", err)114 }115116 modeltest.CompareSettings(t, newDBSettings, newSettings)117 })118}119120func TestNumFetchPosts(t *testing.T) {121 WithPgStore(t.Context(), t, func(t *testing.T, s *pgStore) {122 siteName := "TEST_SITE"123 siteDisplayName := "TestSiteButPretty"124 site, err := s.CreateSite(t.Context(), siteName, siteDisplayName)125 if err != nil {126 t.Fatalf("unexpected error: %s", err)127 }128129 if site.DisplayName != siteDisplayName {130 t.Errorf(131 "wrong returned site display name: got=%s want=%s",132 site.DisplayName,133 siteDisplayName,134 )135 }136137 if site.Name != siteName {138 t.Errorf(139 "wrong returned site name: got=%s want=%s",140 site.Name,141 siteName,142 )143 }144145 if site.LastFetched != nil {146 t.Errorf("last_fetched should be nil. got:%v", site.LastFetched)147 }148149 now := time.Now()150151 err = s.WriteRequest(t.Context(), &model.Request{152 ID: 1234567,153 Host: netip.AddrPort{},154 SiteId: site.ID,155 NumPosts: 100,156 DefaultNumPosts: false,157 Timestamp: now,158 })159 if err != nil {160 t.Fatalf("unexpected error: %s", err)161 }162 n, err := s.NumFetchPostsForSite(t.Context(), site.ID, now)163 if err != nil {164 t.Fatalf("unexpected error: %s", err)165 }166 if n != 100 {167 t.Fatalf("NumFetchPosts()=%d. want=100", n)168 }169170 // After another request, we would expect to still want 100 posts, since171 // that number is still within the window.172 now = now.Add(23 * time.Hour)173 err = s.WriteRequest(t.Context(), &model.Request{174 ID: 1234567,175 Host: netip.AddrPort{},176 SiteId: site.ID,177 NumPosts: 10,178 DefaultNumPosts: false,179 Timestamp: now,180 })181 if err != nil {182 t.Fatalf("unexpected error: %s", err)183 }184 n, err = s.NumFetchPostsForSite(t.Context(), site.ID, now)185 if err != nil {186 t.Fatalf("unexpected error: %s", err)187 }188 if n != 100 {189 t.Fatalf("NumFetchPosts()=%d. want=100", n)190 }191192 // After 24 hours, the 100 post request should have aged out.193 now = now.Add(1 * time.Hour)194 n, err = s.NumFetchPostsForSite(t.Context(), site.ID, now)195 if err != nil {196 t.Fatalf("unexpected error: %s", err)197 }198 if n != 10 {199 t.Fatalf("NumFetchPosts()=%d. want=10", n)200 }201 })202}203204func WithDatabase(205 ctx context.Context,206 t *testing.T,207 f func(*testing.T, *pgxpool.Pool),208) {209 if *pgaddr == "" {210 t.Fatal("Test database not defined")211 }212213 pgUrl, err := url.Parse(*pgaddr)214 if err != nil {215 t.Fatalf("faled to parse db URL: %s", err)216 }217218 parentPool, err := pgxpool.New(ctx, *pgaddr)219 if err != nil {220 t.Fatalf("failed to connect to postgres: %s", err)221 }222223 if err := parentPool.Ping(ctx); err != nil {224 t.Fatalf("unable to ping DB: %s", err)225 }226227 dbName := fmt.Sprintf("website_feeds_test_%d", rand.Int64())228229 _, err = parentPool.Exec(t.Context(), `CREATE DATABASE `+dbName)230 if err != nil {231 t.Fatalf("failed to create database: %s", err)232 }233234 pgUrl.Path = dbName235 pool, err := pgxpool.New(ctx, pgUrl.String())236 if err != nil {237 t.Fatalf("failed to connect to postgres: %s", err)238 }239 defer pool.Close()240241 if err := pool.Ping(ctx); err != nil {242 t.Fatalf("unable to ping DB: %s", err)243 }244245 defer func() {246 _, err := parentPool.Exec(t.Context(), `DROP DATABASE `+dbName)247 if err != nil {248 t.Fatalf("failed to clean up database: %s", err)249 }250 }()251252 f(t, pool)253}254255func WithPgStore(256 ctx context.Context,257 t *testing.T,258 f func(*testing.T, *pgStore),259) {260 WithDatabase(ctx, t, func(t *testing.T, pool *pgxpool.Pool) {261 s, err := newFromPool(ctx, pool)262 if err != nil {263 t.Fatalf("unable to create pgStore: %s", err)264 }265266 defer s.Close()267268 f(t, s)269 })270}