데이터 엔지니어링 정복/Algorithm

[완전탐색&백트래킹] 백준 - NM과 K

eeaarrtthh 2021. 9. 1. 17:37
728x90
반응형

https://www.acmicpc.net/problem/18290

 

18290번: NM과 K (1)

크기가 N×M인 격자판의 각 칸에 정수가 하나씩 들어있다. 이 격자판에서 칸 K개를 선택할 것이고, 선택한 칸에 들어있는 수를 모두 더한 값의 최댓값을 구하려고 한다. 단, 선택한 두 칸이 인접

www.acmicpc.net

 

-문제해설

일단 아이디어는 간단하다. 

 

main메서드에서 n m k값 입력받고 좌표값 입력받아서 arr 배열에 넣어준다.

그리고 visit 배열도 만들어준다. 이는 방문처리를 위한 배열이다.

 

findMax( int depth ) 메서드를 호출한다.

findMax 메서드에서는 좌표를 (0, 0)에서 (n-1, m-1) 순회한다.

순회하는 값을 행은 i, 열은 j라고 하자.

 

이때 i, j의 상하좌우를 방문처리해야 하므로 ok( int row, int col ) 메서드를 호출한다.

ok메서드에서는 row, col이 범위를 벗어나지 않는다면( 0~n / 0~m ) 해당 좌표를 방문처리한다.

( row, col )++

( row-1, col )++

( row+1, col )++

( row, col-1 )++

( row, col+1 )++

 

그리고 ans배열은 선택한 값을 집어넣는 배열이다.  방문처리되지 않은 값을 ans[depth] = arr[i][j] 를 통해서 집어넣는다.

그리고 findMax( depth+1 ) 재귀호출해준다.

depth == k가 되면 ans에 저장된 값을 모조리 더해준 뒤 기존 max값과 비교해서 큰 값으로 대체한다.

그리고 함수를 return해준다.

 

return으로 백트래킹되면 다시 visit 방문처리 값을 -1 감소시켜준다. 이는 comeBack( int row, int col )메서드가 담당한다.

ok메서드와 구조는 똑같고 ++가 --가 된다.

 

이렇게해서 전체 좌표에 대한 k개를 선택하여 최대값을 만드는 모든 경우를 탐색할 수 있다.

완전탐색이 끝났다면 max를 출력해주면 된다.

 

 

 

파이썬의 경우 자바와 똑같이 풀었더니 시간초과가 나와서 방문처리해주는 메서드의 중복연산을 피하기 위해

다른 분 풀이를 참고해서 코드를 변경했다.

 

 

-자바

package bruteforce;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class BJ18290 {
	static int n, m, k;
	static int[][] arr, visit; 
	static int[] ans; //선택할 값을 넣을 배열
	static int max = Integer.MIN_VALUE;
	
	public static void findMax( int depth ) {
		//depth값이 k이면 최대값 갱신
		if( depth == k ) {
			int tmp = 0;
			for( int i : ans ) tmp += i;
			max = Math.max( max, tmp );
			
		} else {
			for( int i=0; i<n; i++ ) {
				for( int j=0; j<m; j++ ) {
					if( visit[i][j] > 0 ) continue;
					
					ok( i, j );	//상하좌우 좌표에서 재귀호출되지 않도록 상하좌우의 check값을 증가시킨다.
					ans[depth] = arr[i][j];	//현재값을 ans배열에 넣는다.
					findMax( depth+1 );
					comeBack( i, j );	//상하좌우의 check값을 원상복구 시킨다.
				}
			}
			
		}
	}
	
	public static void ok( int row, int col ) {
		visit[row][col]++;  //현재값에 +1을 해준다.
		
		//row의 경우 0~n, col의 경우 0~m까지 범위를 벗어나지 않는지 검사
		//벗어나지 않으면 해당좌표의 visit값을 1씩 증가한다.
		//왼쪽
		if( col-1 >= 0 ) visit[row][col-1]++;
		//오른쪽
		if( col+1 < m ) visit[row][col+1]++;
		//위쪽
		if( row-1 >= 0 ) visit[row-1][col]++;
		//아래쪽
		if( row+1 < n ) visit[row+1][col]++;
	}
	
	public static void comeBack( int row, int col ) {
		//원상복구를 위해 현재값과 현재값의 상하좌우에 -1을 해준다.
		visit[row][col]--;  
		
		//왼쪽
		if( col-1 >= 0 ) visit[row][col-1]--;
		//오른쪽
		if( col+1 < m ) visit[row][col+1]--;
		//위쪽
		if( row-1 >= 0 ) visit[row-1][col]--;
		//아래쪽
		if( row+1 < n ) visit[row+1][col]--;
	}

	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer( br.readLine() );
		n = Integer.parseInt( st.nextToken() );
		m = Integer.parseInt( st.nextToken() );
		k = Integer.parseInt( st.nextToken() );
		ans = new int[k];
		arr = new int[n][m];
		visit = new int[n][m];
		
		for( int i=0; i<n; i++ ) {
			st = new StringTokenizer( br.readLine() );
			for( int j=0; j<m; j++ ) {
				arr[i][j] = Integer.parseInt( st.nextToken() );
			}
		}
		
		findMax( 0 );
		
		System.out.println( max );
	}
}

 

 

-파이썬 (pypy3로 해야지 통과)

from sys import stdin

def search( row, col, depth, hap ):
    global n, m, k, numMax
    if depth == k:
        if numMax < hap: numMax = hap
        return
    
    for i in range( row, n ):
        for j in range( col if i == row else 0, m ):
            #현재 위치 방문 여부 확인
            if visit[i][j]: continue
            
            ok = True
            #상하좌우 방문여부 확인
            for z in range( 4 ):
                nrow = i + drow[z]
                ncol = j + dcol[z]
                
                if 0<=nrow<n and 0<=ncol<m:
                    if visit[nrow][ncol]: ok = False
                    
            #방문하기
            if ok:
                visit[i][j] = True
                search( i, j, depth+1, hap+arr[i][j] )
                visit[i][j] = False

input = stdin.readline
n, m, k = map(int, input().split())
arr = [ list( map(int, input().split()) ) for _ in range( n ) ]
visit = [ [False]*m for _ in range( n ) ]
numMax = -10001
drow = [ 0, 0, 1, -1 ]
dcol = [ 1, -1, 0, 0 ]

search( 0, 0, 0, 0 )
print( numMax )

 

728x90
반응형