[Java/자바] 프로그래머스 Lv2 - 행렬의 곱셈

문제 풀이

2차원 행렬 arr1과 arr2를 입력받아, arr1에 arr2를 곱한 결과를 반환하는 함수, solution을 완성해주세요.

 

제한 조건
  • 행렬 arr1, arr2의 행과 열의 길이는 2 이상 100 이하입니다.
  • 행렬 arr1, arr2의 원소는 -10 이상 20 이하인 자연수입니다.
  • 곱할 수 있는 배열만 주어집니다.

 

입출력 예
arr1 arr2 return
[[1, 4], [3, 2], [4, 1]] [[3, 3], [3, 3]] [[15, 15], [15, 15], [15, 15]]
[[2, 3, 2], [4, 2, 4], [3, 1, 4]] [[5, 4, 3], [2, 4, 1], [3, 1, 1]] [[22, 22, 11], [36, 28, 18], [29, 20, 14]]

Solution.java

이 문제를 풀기 위해선 행렬의 특징을 이해해야합니다.

행렬의 특징을 이해하기 쉬운 그림 같이 첨부합니다.

만약 arr1 행렬, arr2 행렬이 있을 때

"곱한 행렬의 크기는 arr1 행렬의 행 개수 + arr2 행렬의 열 개수"가 됩니다. (arr1 행렬의 행 개수 + arr2 행렬의 열 개수 = 곱한 행렬 행렬의 크기)

즉, 아래의 그림을 보면서 설명하면 3개의 행을 가진 arr1과 2개의 열을 가진 arr2의 곱으로 만들어지는 행렬은 3행 2열이 됩니다.

 

또한, arr1 행렬의 열과 arr2 행렬의 행의 개수가 같아야 한다는 조건을 가집니다.

 

통과하지 못한 코드
import java.util.Arrays;

public class 행렬의곱셈 {
    public static void main(String[] args) {
        int[][] arr1 = new int[][] {{1, 4}, {3, 2}, {4, 1}};
        int[][] arr2 = new int[][] {{3, 3}, {3, 3}};
        System.out.println(Arrays.deepToString(solution(arr1, arr2)));

        int[][] arr3 = new int[][] {{2, 3, 2}, {4, 2, 4}, {3, 1, 4}};
        int[][] arr4 = new int[][] {{5, 4}, {2, 4}, {3, 1}};
        System.out.println(Arrays.deepToString(solution(arr3, arr4)));
    }

    public static int[][] solution(int[][] arr1, int[][] arr2) {
        int[][] answer = new int[arr1.length][arr2[0].length];

        for (int i = 0; i < arr1.length; i++) {
            for (int j = 0; j < arr1[0].length; j++) {
                for (int k = 0; k < arr2[0].length; k++) {
                    answer[i][j] += arr1[i][k] * arr2[k][j];
                }
            }
        }
        return answer;
    }
}

프로그래머스에서 테스트 코드는 통과했지만 제출 시 런타임 오류로 무엇이 문제인지 난감했었습니다.

하지만 알아보니 두 배열의 열의 개수가 다를 때 index 범위 오류가 나더라구요.

그래서 이를 해결하기 위해 아래처럼 수정하였습니다.

 

통과한 코드
import java.util.Arrays;

public class 행렬의곱셈 {
    public static void main(String[] args) {
        int[][] arr1 = new int[][] {{1, 4}, {3, 2}, {4, 1}};
        int[][] arr2 = new int[][] {{3, 3}, {3, 3}};
        System.out.println(Arrays.deepToString(solution(arr1, arr2)));

        int[][] arr3 = new int[][] {{2, 3, 2}, {4, 2, 4}, {3, 1, 4}};
        int[][] arr4 = new int[][] {{5, 4}, {2, 4}, {3, 1}};
        System.out.println(Arrays.deepToString(solution(arr3, arr4)));
    }

    public static int[][] solution(int[][] arr1, int[][] arr2) {
        int[][] answer = new int[arr1.length][arr2[0].length];

        for (int i = 0; i < arr1.length; i++) {
            for (int j = 0; j < arr2[0].length; j++) {
                for (int k = 0; k < arr1[0].length; k++) {
                    answer[i][j] += arr1[i][k] * arr2[k][j];
                }
            }
        }
        return answer;
    }
}

위 코드와 다르게 반복문이 돌 때 행렬이 곱해지는 방식을 보면서

j 반복문의 크기를 arr2의 열의 개수로 바꾸고 k 반복문의 크기를 arr1의 열의 개수로 바꾸었더니 통과할 수 있었습니다.

 

통과하지 못했던 이유를 두 행렬의 값이 아래와 같을 경우의 출력 결과를 보시면 그 이유를 금방 이해하실 수 있을겁니다.

int[][] arr3 = new int[][] {{2, 3, 2}, {4, 2, 4}, {3, 1, 4}};
int[][] arr4 = new int[][] {{5, 4}, {2, 4}, {3, 1}};
  • 통과 못한 코드의 출력 결과 일부

 

  • 통과한 코드의 출력 결과 일부